Flow Matching: A visual introduction
Flow Matching (FM) has become a prevalent technique to train a certain class of generative models. In this post we'll try to explore the intuition behind flow matching and how it works.
We'll use this notebook to build a simple flow matching model that illustrates linear flow matching on a minimal toy example. Our goal is to keep things simple, intuitive, and visual. We won't dive deep into the mathematical details of the model; if you're interested in those, I recommend checking out the references at the end of this post.
# Imports and setup
import base64
import functools
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import HTML, display
from matplotlib.animation import FuncAnimation
from tqdm import tqdm
sns.set_style("darkgrid") # Set the style of the plots
pd.options.display.float_format = "{:,.3f}".format # Table display format
# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(626)
# PyTorch Device configuration
DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
Flow matching
Flow matching is a technique to learn how to transport samples from one distribution to another. For example, we could learn how to transport samples from a simple distribution we can easily sample from (e.g. Gaussian noise) to a complex distribution (e.g. images, videos, or robot actions).
Toy Example: Mapping Gaussian noise to a bimodal distribution
In this post we'll build a simple toy example of a generative model using flow matching. For illustrative purposes we'll start with a simple 1D bimodal target distribution $π_1$ and learn how to transport samples from a 1D Gaussian noise distribution $π_0$ to this target distribution.
In practice the target points $x_1 \sim π_1$ are approximated by sampling from a limited dataset of training points $X_1$ and the noise points $x_0 \sim π_0$ are sampled from a chosen noise distribution $π_0$ that is easy to sample from (e.g. Gaussian noise).
# Define 1D bimodal target distribution
mixture_prob = np.array([0.55, 0.45], dtype=float) # Mixture weights
mixture_mus = np.array([-0.85, 1.5], dtype=float) # Means of the two Gaussian modes
mixture_sigmas = np.array([0.65, 0.25], dtype=float) # Standard deviations of the modes
def mixture_pdf(x: np.ndarray) -> np.ndarray:
    """Compute the PDF of a mixture of Gaussians."""
    comps = scipy.stats.norm.pdf(x[None, :], loc=mixture_mus[:, None], scale=mixture_sigmas[:, None])
    return np.sum(mixture_prob[:, None] * comps, axis=0)


def mixture_sample(size: int) -> np.ndarray:
    """Sample from a mixture of Gaussians."""
    rand_idx = np.random.choice(range(len(mixture_prob)), size=size, p=mixture_prob)
    means = mixture_mus[rand_idx]
    stds = mixture_sigmas[rand_idx]
    return np.random.normal(loc=means, scale=stds)
# Plot data distribution. This is the TARGET distribution (π₁)
fig, ax = plt.subplots(1, 1, figsize=(8, 3), constrained_layout=True, dpi=100)
x_all_steps = np.linspace(-3, 3, 1000)
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax.plot(x_all_steps, pdf_noise, label="PDF Noise π₀", color="tab:orange")
ax.fill_between(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
pdf_target = mixture_pdf(x=x_all_steps)
ax.plot(x_all_steps, pdf_target, label="PDF Target π₁", color="tab:blue")
ax.fill_between(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax.legend()
ax.set_title("Toy Example Data: Noise (π₀) vs Target (π₁) Distributions")
ax.set_xlabel("x")
ax.set_ylabel("density")
plt.show()
del fig, ax, x_all_steps, pdf_noise, pdf_target
#
The flow matching model predicts a velocity field
A flow matching model does not predict flow paths directly, but instead predicts a velocity field that can be used to sample the flow paths. The velocity field describes how to move a sample from the noise distribution to the target distribution.
We can describe the flow matching model with learnable parameters $\theta$ as a function: $${FM}_{\theta}(x_t, t) = v(x_t, t)$$ This function takes a sample $x_t$ at flow step $t$ and predicts the velocity vector $v(x_t, t) = dx_t / dt$ that describes how to move the sample $x_t$ closer to the target distribution at step $t$.
The step $t$ is a value between 0 and 1 that describes the progress of the sample $x_t$ along the flow path from the noise distribution to the target distribution. When $t=0$ the sample $x_t = x_0$ is a sample from the noise distribution $π_0$ and when $t=1$ the sample $x_t = x_1$ is a sample from the target distribution $π_1$.
At inference time we can sample a starting point $x_0$ from the noise distribution $π_0$ and then use the predicted velocity field ${FM}_{\theta}(x_t, t)$ to iteratively move the sample towards the target distribution $π_1$ in small steps $dt$.
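The iterative update described above can be sketched as a plain Euler integration loop. This is a minimal illustration, not the notebook's own sampler: `toy_velocity` is a hypothetical stand-in for the trained model ${FM}_{\theta}$, chosen so that every sample ends up at a single assumed target value, and `euler_integrate` is a name introduced only for this sketch.

```python
import numpy as np


def toy_velocity(x_t, t, target=1.5):
    """Hypothetical velocity field (NOT a trained model): carries every sample to `target` at t=1."""
    return (target - x_t) / (1.0 - t)


def euler_integrate(velocity_fn, x_0, n_steps=100):
    """Move samples from t=0 to t=1 with the update x_{t+dt} = x_t + v(x_t, t) * dt."""
    dt = 1.0 / n_steps
    x_t = np.asarray(x_0, dtype=float).copy()
    for i in range(n_steps):
        t = i * dt  # t takes the values 0, dt, ..., 1 - dt, so (1 - t) never reaches 0
        x_t = x_t + velocity_fn(x_t, t) * dt
    return x_t


rng = np.random.default_rng(0)
x_0 = rng.standard_normal(5)  # starting points sampled from Gaussian noise
x_1_hat = euler_integrate(toy_velocity, x_0)
print(x_1_hat)
```

With a trained model the loop stays the same; only `velocity_fn` is swapped for the network's prediction. Smaller steps `dt` follow the velocity field more closely at the cost of more model evaluations.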
This is illustrated in the following animation (generated further down in the notebook), which shows the integration of a sample from the noise distribution $π_0$ on the left towards the target distribution $π_1$ on the right using the predicted velocity field ${FM}_{\theta}(x_t, t)$. The velocity field is visualized as a heatmap where the vertical axis represents the position of the sample $x_t$ and the horizontal axis represents the flow step $t$ going from 0 on the left to 1 on the right. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).
# Display the animation in the notebook
# This animation is generated further down in the notebook; if it doesn't exist yet we skip the display
# Embed the MP4 video directly in the notebook by encoding its bytes as base64, so it should also survive export
ANIMATION_FILE = Path("flow_matching_path_integration.mp4")
if ANIMATION_FILE.exists():
    with ANIMATION_FILE.open("rb") as f:
        gif_data = f.read()  # Raw video bytes (the name is historical, the file is an MP4)
    display(
        HTML(f"""
        <video alt="Flow matching path integration" loop="true" autoplay="autoplay" muted>
            <source type="video/mp4" src="data:video/mp4;base64,{base64.b64encode(gif_data).decode()}">
        </video>
        """)
    )
else:
    print(
        "Animation file not yet created, it is generated further down in the notebook. Run the full notebook to generate it."
    )
#
Training the flow matching model is learning the velocity field
Since the flow matching model ${FM}_{\theta}(x_t, t)$ should predict the velocity field $v(x_t, t) = dx_t / dt$ we can train the model on samples of velocity vectors $\mathbf{v}(x_t, t)$.
The flow matching training objective is to minimize the expected reconstruction error of the velocity field: $$ \underset{\theta}{\text{argmin}} \; \mathbb{E}_{t, x_t} \Big\| {FM}_{\theta}(x_t, t) - v(x_t, t) \Big\|^2 $$
with $t \sim \mathcal{U}[0, 1]$ and $x_t$ taken from a sampled reference path evaluated at flow step $t$.
We'll be using straight line reference paths in this post since they are simple and common.
Training: Straight line reference paths
We're going to focus on a common variant of flow matching where we learn a flow matching model based on straight line reference paths. Training flow matching with straight-line conditional paths and independent couplings is also equivalent to the rectified flow training objective.
Linear (straight line) flow matching is trained on a set of reference paths between the noise and target distributions. More specifically, linear flow matching prefers learning from straight line trajectories between the noise and target distributions because they tend to give straighter paths that require fewer steps to reconstruct the target distribution.
To sample a reference path, we independently sample a target point $x_1$ from our target distribution $π_1$ and a noise point $x_0$ from the noise distribution $π_0$. This gives us a single coupling $(x_0, x_1)$ that defines a straight-line reference path between the noise and target samples. During training we'll sample a large set of such couplings $(X_0, X_1)$, each inducing one reference path, and use these to train the flow matching model.
The following code illustrates how we define the straight line reference path between a noise and target sample.
def interpolate_linear(x_0, x_1, t):
    """Evaluates the linear interpolation path between x_0 and x_1 at step t."""
    x_t = ((1 - t) * x_0) + (t * x_1)
    return x_t
The following figure shows a few sampled straight-line reference paths, as well as the reference path distribution approximated by sampling a large number of straight-line reference paths.
# Illustration of the sampled reference paths
# Set up the plot
fig, ((ax11, ax12, ax13), (ax21, ax22, ax23)) = plt.subplots(
2,
3,
figsize=(12, 8),
gridspec_kw={"width_ratios": [1, 5, 1]},
sharey=True,
dpi=100,
)
fig.subplots_adjust(wspace=0)
x_min, x_max = -2.5, 2.5
x_all_steps = np.linspace(x_min, x_max, 1000)
# Sample set of noise and target points
data_size: int = 100_000
np.random.seed(1) # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)
data_x_1 = mixture_sample(size=data_size)
# Plot a few sample paths ##########################################
# Plot Noise distribution π₀
ax11.set_title("Noise Distribution π₀")
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax11.plot(pdf_noise, x_all_steps, label="PDF Noise π₀", color="tab:orange")
ax11.fill_betweenx(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
ax11.invert_xaxis()
ax11.set_ylabel("x")
ax11.set_xlabel("")
# ax11.set_xlim(0, 1)
# Plot final distribution x1
ax13.set_title("Target Distribution π₁")
pdf_target = mixture_pdf(x=x_all_steps)
ax13.plot(pdf_target, x_all_steps, label="PDF Target π₁", color="tab:blue")
ax13.fill_betweenx(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax13.set_xlabel("")
ax13.yaxis.set_label_position("right")
ax13.set_ylabel("x")
# Also show y-axis values on the right side
ax13.yaxis.set_visible(True)
ax13.yaxis.set_tick_params(labelright=True)
# ax13.set_xlim(0, 1)
# Plot the Sample paths
nb_samples = 7
t = np.arange(0, 1, 0.01)
ax12.set_title("Sample of straight line reference paths")
colors = plt.cm.tab10.colors
for i in range(nb_samples):
    color = colors[i % len(colors)]
    ax12.plot(t, interpolate_linear(x_0=data_x_0[i], x_1=data_x_1[i], t=t), alpha=0.5, color=color)
    ax12.scatter([0, 1], [data_x_0[i], data_x_1[i]], color=color)
ax12.set_ylim(x_min, x_max)
ax12.set_xlim(0, 1)
ax12.set_xlabel("$t$: flow step")
# Plot the full data distribution ##################################
# Plot Noise samples X0
ax21.set_title("Noise Samples X₀")
ax21.hist(
data_x_0.flatten(),
bins=100,
alpha=0.5,
label="Noise π₀",
color="tab:orange",
density=True,
orientation="horizontal",
)
ax21.invert_xaxis()
ax21.set_ylabel("x")
ax21.set_xlabel("density")
ax21.sharex(ax11)
# ax21.set_xlim(0, 1)
# Plot target data distribution x1
ax23.set_title("Target Data X₁")
ax23.hist(
data_x_1.flatten(), bins=100, alpha=0.5, label="Target π₁", color="tab:blue", density=True, orientation="horizontal"
)
ax23.set_xlabel("density")
ax23.yaxis.set_label_position("right")
ax23.set_ylabel("x")
# Also show y-axis values on the right side
ax23.yaxis.set_visible(True)
ax23.yaxis.set_tick_params(labelright=True)
ax23.sharex(ax13)
# ax23.set_xlim(0, 1)
# Plot path density
n_samples = int(data_x_1.shape[0])
dt: float = 0.01 # Step size for Euler integration
n_flow_steps = int(1 / dt)
# Set up the path density histogram parameters
img_hist_size = 480
path_density_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
# Add the histogram of the initial distribution
path_density_bins[0] = np.histogram(data_x_0, bins=flow_field_img_x_bin_edges)[0]
path_density_bins[-1] = np.histogram(data_x_1, bins=flow_field_img_x_bin_edges)[0]
# Build up the histogram of the reference paths by going over the discretized t-bins
for i in range(n_flow_steps):
    t = np.full((n_samples,), i * dt)
    x_t = interpolate_linear(x_0=data_x_0, x_1=data_x_1, t=t)
    path_density_bins[i] = np.histogram(x_t, bins=flow_field_img_x_bin_edges)[0]
im = ax22.imshow(path_density_bins.T, aspect="auto", origin="lower", extent=[0, 1, x_min, x_max], cmap="viridis")
ax22.set_xlabel("$t$: flow step")
ax22.set_title("Reference path density between X₀ and X₁")
ax22.grid(False)
plt.tight_layout()
plt.show()
del (fig, ax11, ax12, ax13, ax21, ax22, ax23, x_all_steps, data_x_0, data_x_1, n_samples, dt, n_flow_steps, # fmt: skip
img_hist_size, path_density_bins, flow_field_img_x_bin_edges, i, t, x_t, im) # fmt: skip
#
Training: Sampling velocity vectors
Since we are using straight-line reference paths, the sampled velocity vectors $\mathbf{v}(x_t, t)$ have a very simple form. Given a sample $x_0$ from the noise distribution and a sample $x_1$ from the target distribution, the path is $x_t = (1 - t) x_0 + t x_1$, and differentiating with respect to $t$ gives the conditional velocity along the straight line connecting $x_0$ and $x_1$: $\mathbf{v}(x_t, t) = dx_t / dt = x_1 - x_0$. This velocity is constant along the whole path, as illustrated in the following code and figure.
def get_target_velocity(x_0, x_1):
    """
    Get the velocity for a given pair of noise and target points.
    This is the per-pair (conditional) velocity along the straight path.
    """
    return x_1 - x_0
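As a quick sanity check of the claim that the velocity along a straight-line path is $x_1 - x_0$, the sketch below compares that expression with a numerical derivative of the path. It redefines `interpolate_linear` so that it runs on its own, reuses the example values from this post's velocity figure, and introduces `eps` as a hypothetical finite-difference step.

```python
def interpolate_linear(x_0, x_1, t):
    """Point on the straight line from x_0 (at t=0) to x_1 (at t=1)."""
    return (1 - t) * x_0 + t * x_1


x_0, x_1, t = 1.32, -1.92, 0.67  # example values used in this post's velocity figure
eps = 1e-6  # hypothetical finite-difference step, chosen only for this check

x_t = interpolate_linear(x_0, x_1, t)
# Central difference approximates dx_t/dt at flow step t
v_numeric = (interpolate_linear(x_0, x_1, t + eps) - interpolate_linear(x_0, x_1, t - eps)) / (2 * eps)
v_analytic = x_1 - x_0

print(f"x_t = {x_t:.4f}, numeric v = {v_numeric:.4f}, analytic v = {v_analytic:.4f}")
```

Both velocities agree, and the result does not depend on the chosen `t`, which is why the training target needs only the pair $(x_0, x_1)$.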
# Illustrate the flow matching target velocity vector
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
1,
3,
figsize=(12, 4),
gridspec_kw={"width_ratios": [1, 5, 1]},
sharey=True,
dpi=100,
)
fig.subplots_adjust(wspace=0)
x_min, x_max = -2.5, 2.5
x_all_steps = np.linspace(x_min, x_max, 1000)
# Plot a few sample paths ##########################################
# Plot Noise distribution π₀
ax1.set_title("Noise Distribution π₀")
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax1.plot(pdf_noise, x_all_steps, label="PDF Noise π₀", color="tab:orange")
ax1.fill_betweenx(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("")
# ax11.set_xlim(0, 1)
# Plot final distribution x1
ax3.set_title("Target Distribution π₁")
pdf_target = mixture_pdf(x=x_all_steps)
ax3.plot(pdf_target, x_all_steps, label="PDF Target π₁", color="tab:blue")
ax3.fill_betweenx(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax3.set_xlabel("")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)
# ax13.set_xlim(0, 1)
# Plot the Sample paths
x_0_examplar = 1.32
x_1_examplar = -1.92
t_examplar = 0.67
x_t_examplar = interpolate_linear(x_0=x_0_examplar, x_1=x_1_examplar, t=t_examplar)
# Annotate the path between x_0 and x_1
ax2.set_title("Flow matching target for a sample path: velocity $v(x_t, t) = x_1 - x_0$")
ax2.plot([0, 1], [x_0_examplar, x_1_examplar], alpha=0.8, color="tab:green", label="Path(x₀, x₁)")
ax2.scatter([0, 1], [x_0_examplar, x_1_examplar], color="tab:green")
# ax2.scatter([1], [x_0_examplar], color="tab:green", marker="_")
ax2.legend()
ax2.plot([0, 1], [x_0_examplar, x_0_examplar], alpha=0.5, color="tab:green", linestyle="dotted")
ax2.annotate(
"$x_0$",
xy=(0, x_0_examplar),
xycoords="data",
xytext=(-20, x_0_examplar - 3), # Shift the text to the left of the point
textcoords="offset points",
fontsize=16,
color="tab:green",
annotation_clip=False,
)
ax2.annotate(
" $x_0$",
xy=(1, x_0_examplar),
xycoords="data",
xytext=(1, x_0_examplar - 3),
textcoords="offset points",
fontsize=16,
color="tab:green",
annotation_clip=False,
)
ax2.annotate(
" $x_1$",
xy=(1, x_1_examplar),
xycoords="data",
xytext=(1, x_1_examplar - 3),
textcoords="offset points",
fontsize=16,
color="tab:green",
annotation_clip=False,
)
# Annotate x_t
ax2.annotate(
"$x_t$",
xy=(0, x_t_examplar),
xycoords="data",
xytext=(-20, x_t_examplar), # Shift the text to the left of the point
textcoords="offset points",
fontsize=16,
color="tab:gray",
annotation_clip=False,
)
ax2.plot([0, t_examplar], [x_t_examplar, x_t_examplar], linestyle=":", color="tab:gray")
# Annotate t
ax2.annotate(
"$t$",
xy=(t_examplar, x_min),
xycoords="data",
xytext=(t_examplar - 4, x_min - 12), # Shift the text below the point
textcoords="offset points",
fontsize=16,
color="tab:gray",
annotation_clip=False,
)
ax2.plot([t_examplar, t_examplar], [x_min, x_t_examplar], linestyle=":", color="tab:gray")
# Annotate the velocity vector
ax2.annotate(
"",
xy=(1, x_1_examplar),
xycoords="data",
xytext=(1, x_0_examplar),
arrowprops=dict(arrowstyle="->", color="tab:red", linewidth=2),
annotation_clip=False,
)
ax2.annotate(
" $x_1 - x_0$",
xy=(1, (x_0_examplar + x_1_examplar) / 2),
xycoords="data",
xytext=(1, (x_0_examplar + x_1_examplar) / 2),
textcoords="offset points",
fontsize=16,
color="tab:red",
annotation_clip=False,
)
ax2.scatter([t_examplar], [x_t_examplar], color="tab:red", marker="D", zorder=10)
ax2.annotate(
r" $\mathbf{v}(x_t, t) = x_1 - x_0$",
xy=(t_examplar, x_t_examplar),
xycoords="data",
xytext=(t_examplar, x_t_examplar),
textcoords="offset points",
fontsize=16,
color="tab:red",
annotation_clip=False,
)
ax2.set_ylim(x_min, x_max)
ax2.set_xlim(0, 1)
ax2.set_xlabel(r"$t$: flow step")
plt.tight_layout()
plt.show()
del (fig, ax1, ax2, ax3, x_min, x_max, x_all_steps, # fmt: skip
x_0_examplar, x_1_examplar, t_examplar, x_t_examplar) # fmt: skip
#
Training: Flow matching objective
We can now write out our objective as a function of the samples $x_0$ from the noise distribution and $x_1$ from the target distribution: $$ \underset{\theta}{\text{argmin}} \; \mathbb{E}_{t, x_0, x_1} \Big\| {FM}_{\theta}(x_t, t) - (x_1 - x_0) \Big\|^2 $$ with $t \sim \mathcal{U}[0, 1]$, $x_0 \sim π_0$, $x_1 \sim π_1$, and $x_t = (1 - t) x_0 + t x_1$.
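To make the expectation in this objective concrete, here is a minimal Monte Carlo sketch that estimates it for two hand-written candidate velocity functions instead of a neural network. The single-Gaussian target $\mathcal{N}(2, 0.5^2)$, the helper name `objective`, and both candidates are assumptions made only for this illustration; they are not part of the toy example used in the rest of the post.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000

# Assumption for this sketch: noise is N(0, 1) and the target is a single Gaussian N(2, 0.5^2)
x_0 = rng.standard_normal(n)
x_1 = rng.normal(loc=2.0, scale=0.5, size=n)
t = rng.uniform(0.0, 1.0, size=n)

x_t = (1 - t) * x_0 + t * x_1  # point on the straight-line reference path
v_target = x_1 - x_0           # conditional velocity target


def objective(velocity_fn):
    """Monte Carlo estimate of E || velocity_fn(x_t, t) - (x_1 - x_0) ||^2."""
    return np.mean((velocity_fn(x_t, t) - v_target) ** 2)


loss_zero = objective(lambda x, t: np.zeros_like(x))      # always predicts "no movement"
loss_mean = objective(lambda x, t: np.full_like(x, 2.0))  # always predicts the mean displacement
print(f"zero predictor: {loss_zero:.2f}, mean-displacement predictor: {loss_mean:.2f}")
```

Under these assumptions the displacement $x_1 - x_0$ has mean 2 and variance 1.25, so the first estimate should land near 5.25 and the second near 1.25. A trained model can do better than either constant by using $x_t$ and $t$, but the loss does not reach zero because many different couplings pass through the same $(x_t, t)$.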
Note that each training target $x_1 - x_0$ is conditional on one specific straight-line coupling $(x_0, x_1)$. Many different couplings pass through the same point $(x_t, t)$, and because the squared error is minimized by the mean of its targets, the model learns the average of these conditional velocities, $\mathbb{E}[x_1 - x_0 \mid x_t, t]$. The result is an approximation of the velocity field that does not depend on any specific coupling.
For this simple toy example we could even approximate the flow field directly by sampling a large number of reference paths and computing the average velocity for fixed bins over the flow field. This approximated expectation is illustrated in the following figure, which shows the average flow field. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).
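Because the noise is Gaussian and the target is a mixture of Gaussians, the average velocity for this particular toy example can also be written in closed form, which gives a reference to compare the binned estimate against. The sketch below is a derivation added for illustration, not something taken from the notebook: it assumes $x_0 \sim \mathcal{N}(0, 1)$ sampled independently of $x_1$, applies standard Gaussian conditioning per mixture mode, and the function name `marginal_velocity` is hypothetical.

```python
import numpy as np
from scipy.stats import norm

# Mixture parameters copied from the toy example in this post
mixture_prob = np.array([0.55, 0.45])
mixture_mus = np.array([-0.85, 1.5])
mixture_sigmas = np.array([0.65, 0.25])


def marginal_velocity(x_t: np.ndarray, t: float) -> np.ndarray:
    """Closed-form E[x_1 - x_0 | x_t, t] for N(0, 1) noise and a Gaussian-mixture target."""
    x = np.asarray(x_t, dtype=float)[None, :]  # [1, n_points]
    mu = mixture_mus[:, None]                  # [n_modes, 1]
    var_1 = mixture_sigmas[:, None] ** 2
    # Per mode, x_t = (1 - t) x_0 + t x_1 is Gaussian with this mean and variance
    mean_t = t * mu
    var_t = (1 - t) ** 2 + t**2 * var_1
    # Posterior probability that x_t was produced by each mode
    weights = mixture_prob[:, None] * norm.pdf(x, loc=mean_t, scale=np.sqrt(var_t))
    weights = weights / weights.sum(axis=0, keepdims=True)
    # Gaussian conditioning: E[v | x_t] = E[v] + Cov(v, x_t) / Var(x_t) * (x_t - E[x_t])
    cov_v_xt = t * var_1 - (1 - t)
    v_per_mode = mu + cov_v_xt / var_t * (x - mean_t)
    return (weights * v_per_mode).sum(axis=0)


xs = np.array([-1.0, 0.0, 1.0])
v_start = marginal_velocity(xs, t=0.0)
v_end = marginal_velocity(xs, t=1.0)
print(v_start, v_end)
```

At $t = 0$ the point $x_t$ carries no information about $x_1$, so the velocity reduces to the target mean minus $x_t$. At $t = 1$ we have $x_t = x_1$ while $x_0$ averages to zero, so the velocity equals $x_t$ itself.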
# Flow field illustration, approximate the flow field by sampling a large number of reference paths over discretized flow step bins.
# Sample set of noise and target points
data_size: int = 100_000
np.random.seed(1) # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)
data_x_1 = mixture_sample(size=data_size)
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
1,
3,
figsize=(12, 5),
gridspec_kw={"width_ratios": [1, 5, 1]},
sharey=True,
dpi=100,
constrained_layout=True,
)
x_min, x_max = -2.2, 2.2 # Narrow view for flow field image
x_min_wide, x_max_wide = -3.2, 3.2 # Wide view for sample paths
# Plot the velocity field ##########################################
# Plot Noise samples X0
ax1.set_title("Noise samples X₀")
ax1.hist(
data_x_0.flatten(),
bins=100,
alpha=0.5,
label="Noise π₀",
color="tab:orange",
density=True,
orientation="horizontal",
)
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("density")
# Plot target data distribution x1
ax3.set_title("Target data X₁")
ax3.hist(
data_x_1.flatten(), bins=100, alpha=0.5, label="Target π₁", color="tab:blue", density=True, orientation="horizontal"
)
ax3.set_xlabel("density")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)
# Plot path density
n_samples = int(data_x_1.shape[0])
dt: float = 0.05 # Step size for Euler integration
n_flow_steps = int(1 / dt)
# Set up the path density histogram parameters
img_hist_size = 200
# Narrow view for flow field image
flow_field_img_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
# Wide view to compute the pathlines, need to be wide enough because paths can flow out of the narrow view
flow_field_for_path_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_for_path_x_bin_edges = np.linspace(x_min_wide, x_max_wide, img_hist_size + 1)
# Build up the histogram of the reference paths by going over the discretized t-bins
# We're building 2 histograms:
# - `flow_field_img_bins` narrow view for the flow field image, might have NaNs if there are no samples in a bin
# - `flow_field_for_path_bins` wide view for the pathlines, avoid NaNs by setting the velocity to 0 if there are no samples in a bin
v = get_target_velocity(x_0=data_x_0, x_1=data_x_1)
for i in range(n_flow_steps + 1):
    t = np.full((n_samples,), i * dt)
    x_t = interpolate_linear(x_0=data_x_0, x_1=data_x_1, t=t)
    # Get the average velocity for each bin in the narrow view for the flow field image
    x_t_bin_indices = np.digitize(x=x_t, bins=flow_field_img_x_bin_edges[1:-1], right=False)
    counts = np.bincount(x_t_bin_indices, minlength=img_hist_size)
    sums = np.bincount(x_t_bin_indices, weights=v, minlength=img_hist_size)
    flow_field_img_bins[i] = np.divide(
        sums,
        counts,
        out=np.full(img_hist_size, np.nan, dtype=float),
        where=counts > 0,
    )
    # Get the average velocity for each bin in the wide view for the pathlines
    x_t_bin_indices_wide = np.digitize(x=x_t, bins=flow_field_for_path_x_bin_edges[1:-1], right=False)
    counts_wide = np.bincount(x_t_bin_indices_wide, minlength=img_hist_size)
    sums_wide = np.bincount(x_t_bin_indices_wide, weights=v, minlength=img_hist_size)
    flow_field_for_path_bins[i] = np.divide(
        sums_wide,
        counts_wide,
        out=np.zeros(img_hist_size, dtype=float),  # Avoid NaNs for any path sampling
        where=counts_wide > 0,
    )
# Plot the flow field
max_abs_flow_field = np.nanmax(np.abs(flow_field_img_bins))
color_norm = matplotlib.colors.TwoSlopeNorm(vmin=-max_abs_flow_field, vcenter=0.0, vmax=max_abs_flow_field)
im = ax2.imshow(
flow_field_img_bins.T,
aspect="auto",
origin="lower",
extent=[0, 1, x_min, x_max],
cmap="coolwarm",
norm=color_norm,
)
ax2.set_ylim(x_min, x_max)
cbar = fig.colorbar(
im,
ax=[ax1, ax2, ax3],
orientation="horizontal",
fraction=0.08,
aspect=40,
pad=0.04,
)
cbar.set_label("Velocity field value (red pushes up, blue pulls down)")
# Sample some paths from the mean flow field to show the pathlines
n_paths = 9
paths = np.zeros((n_flow_steps + 1, n_paths))
paths[0] = np.linspace(start=x_min_wide, stop=x_max_wide, num=n_paths)
for i in range(n_flow_steps):
    t = np.full((n_samples,), i * dt)
    x_tm1 = paths[i]
    x_tm1_bin_indices = np.digitize(x=x_tm1, bins=flow_field_for_path_x_bin_edges[1:-1], right=False)
    v = flow_field_for_path_bins[i, x_tm1_bin_indices]
    paths[i + 1] = x_tm1 + v * dt
# Plot the pathlines with arrows using quiver
t_coords = np.linspace(0, 1, n_flow_steps + 1)
arrow_stride = 7  # space out arrows for clarity
for i in range(n_paths):
    y = paths[:, i]
    idx = np.arange(0, len(t_coords) - 1, arrow_stride)
    x0 = t_coords[idx]
    y0 = y[idx]
    u = np.diff(t_coords)[idx]
    v = np.diff(y)[idx]
    ax2.quiver(
        x0,
        y0,
        u,
        v,
        angles="xy",
        scale_units="xy",
        scale=0.6,
        units="inches",
        width=0.015,
        headwidth=6,
        headlength=9,
        headaxislength=7,
        pivot="tail",
        color="dimgray",
        alpha=0.9,
        zorder=2,
    )
    # overlay original line on top of arrows for clarity
    ax2.plot(
        t_coords,
        y,
        linestyle="-",
        color="dimgray",
        linewidth=1.5,
        alpha=0.7,
        zorder=3,
        label=f"Sample path {i + 1}",
    )
ax2.set_xlabel("$t$: flow step")
ax2.set_title("Average velocity field with pathlines")
ax2.grid(False)
plt.show()
del (fig, ax1, ax2, ax3, x_min, x_max, x_min_wide, x_max_wide, data_size, # fmt: skip
data_x_0, data_x_1, n_samples, dt, n_flow_steps, img_hist_size, # fmt: skip
flow_field_img_bins, flow_field_img_x_bin_edges, flow_field_for_path_bins, # fmt: skip
flow_field_for_path_x_bin_edges, v, t, x_t, x_t_bin_indices, counts, sums, # fmt: skip
x_t_bin_indices_wide, counts_wide, sums_wide) # fmt: skip
#
Training the Flow Matching model
Now that we have defined our optimization objective and how to sample the data to train the model, we can define the flow matching model and train it. We'll create a small neural network (a multilayer perceptron with two hidden layers) that we can train to predict the velocity field.
class FlowMatchingModel(nn.Module):
    """
    Flow Matching model to predict the velocity field at time t and position x_t.
    """

    def __init__(self, data_dim: int, hidden_dim: int) -> None:
        super().__init__()
        # Simple MLP
        self.net: nn.Sequential = nn.Sequential(
            nn.Linear(data_dim + 1, hidden_dim),  # +1 for the flow step t, concatenated to x_t
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, data_dim),
        )

    def forward(
        self,
        t: torch.Tensor,  # Flow step [batch_size, 1]
        x_t: torch.Tensor,  # Interpolated samples [batch_size, data_dim]
    ) -> torch.Tensor:  # [batch_size, data_dim]
        """
        Predicts the velocity field at time t and position x_t.
        """
        tx: torch.Tensor = torch.cat([t, x_t], dim=-1)
        return self.net(tx)
We can now define the loss function as a function of the flow matching model, the noise samples $X_0$, the target samples $X_1$, and the flow steps $T$:
def compute_loss(
    flow_matching_model: FlowMatchingModel,
    x_0: torch.Tensor,
    x_1: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the loss for a single batch of (X_0, X_1) couplings and flow steps T.
    """
    # Interpolate the data at the sampled time step
    x_t = interpolate_linear(x_0=x_0, x_1=x_1, t=t)
    # Get the target velocity
    v_target = get_target_velocity(x_0=x_0, x_1=x_1)
    # Predict the velocity
    v_pred = flow_matching_model(t=t, x_t=x_t)
    # Compute the loss (mean squared error between predicted and target velocity)
    loss = ((v_pred - v_target) ** 2).mean()
    return loss
Using this loss function we can now train the flow matching model in a straightforward gradient-based optimization loop. We'll use a standard Adam optimizer to optimize the model parameters.
# Train the flow matching model
# Hyperparameters
data_dim: int = 1 # 1D data
hidden_dim: int = 64
nb_train_iterations: int = 10_000
lr: float = 1e-3
batch_size: int = 256
# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(626)
# Initialize the vector field network and optimizer
flow_matching_model = FlowMatchingModel(data_dim=data_dim, hidden_dim=hidden_dim).to(DEVICE).train()
optimizer = optim.Adam(flow_matching_model.parameters(), lr=lr)
# Training loop
losses: list[float] = []
with tqdm(range(nb_train_iterations), desc="Training", unit="iteration") as progress_bar:
    for i in progress_bar:
        # Sample a batch of target and noise samples
        x_1 = torch.from_numpy(mixture_sample(size=batch_size)).to(dtype=torch.float32, device=DEVICE).unsqueeze(-1)
        x_0 = torch.randn_like(x_1)
        # Sample a random time step for each sample in the batch
        t = torch.rand(x_1.shape[0], device=DEVICE).unsqueeze(-1)
        # Compute the loss
        loss = compute_loss(flow_matching_model=flow_matching_model, x_0=x_0, x_1=x_1, t=t)
        # Backpropagate the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        progress_bar.set_postfix({"Loss": f"{loss.item():.2f}"})
# Plot loss curve after training
fig, ax = plt.subplots(figsize=(12, 3), dpi=100)
ax.plot(losses, color="tab:blue", alpha=0.5, label="Loss")
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Curve")
# Plot a smoothed loss curve using a simple moving average
window_size = 100
smoothed_losses = np.convolve(losses, np.ones(window_size) / window_size, mode="valid")
ax.plot(np.arange(window_size - 1, len(losses)), smoothed_losses, color="tab:blue", label="Loss (moving avg)")
ax.legend(loc="upper right")
ax.set_xlim(0, len(losses))
ax.grid(True)
plt.show()
del fig, ax, window_size, smoothed_losses
#
Visualizing the trained flow matching model
Now that we have trained this simple flow matching model we can visualize the learned velocity field by getting the predicted velocity field ${FM}_{\theta}(x_t, t)$ at a grid of points $(t, x_t)$ and plotting this grid of velocities as a color image. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).
# Flow field visualization of trained model
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
1,
3,
figsize=(12, 5),
gridspec_kw={"width_ratios": [1, 5, 1]},
sharey=True,
dpi=100,
constrained_layout=True,
)
x_min, x_max = -2.2, 2.2 # Narrow view for flow field image
x_min_wide, x_max_wide = -3.2, 3.2 # Wide view for sample paths
# Sample set of noise and target points
data_size: int = 100_000
np.random.seed(1) # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)
# Plot the velocity field ##########################################
# Plot Noise samples X0
ax1.set_title("Noise Samples X₀")
ax1.hist(
data_x_0.flatten(),
bins=100,
alpha=0.5,
label="Noise Samples X₀",
color="tab:orange",
density=True,
orientation="horizontal",
)
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("density")
# Plot Flow Field
n_flow_steps = 100
# Set up the flow field histogram parameters
img_hist_size = 200
# Narrow view for flow field image
flow_field_img_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
flow_field_img_x_bin_centers = (
torch.from_numpy((flow_field_img_x_bin_edges[:-1] + flow_field_img_x_bin_edges[1:]) / 2)
.float()
.to(DEVICE)
.unsqueeze(-1)
)
# Wide view to compute the pathlines, need to be wide enough because paths can flow out of the narrow view
flow_field_for_path_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_for_path_x_bin_edges = np.linspace(x_min_wide, x_max_wide, img_hist_size + 1)
flow_field_for_path_x_bin_centers = (
torch.from_numpy((flow_field_for_path_x_bin_edges[:-1] + flow_field_for_path_x_bin_edges[1:]) / 2)
.float()
.to(DEVICE)
.unsqueeze(-1)
)
# Evaluate the trained model on a grid of (t, x) points, one row of bin centers per flow step
# We're filling 2 grids:
# - `flow_field_img_bins` narrow view for the flow field image
# - `flow_field_for_path_bins` wide view covering the range the pathlines can reach
with torch.inference_mode():
    flow_matching_model.eval()
    for i, t in enumerate(torch.linspace(0, 1, n_flow_steps + 1)):
        t = t.expand_as(flow_field_img_x_bin_centers).to(DEVICE)
        # Get the model's prediction for the velocity field at the bin centers
        flow_field_img_bins[i] = flow_matching_model(t=t, x_t=flow_field_img_x_bin_centers).cpu().numpy().squeeze()
        flow_field_for_path_bins[i] = (
            flow_matching_model(t=t, x_t=flow_field_for_path_x_bin_centers).cpu().numpy().squeeze()
        )
# Plot the flow field
max_abs_flow_field = np.nanmax(np.abs(flow_field_img_bins))
color_norm = matplotlib.colors.TwoSlopeNorm(vmin=-max_abs_flow_field, vcenter=0.0, vmax=max_abs_flow_field)
im = ax2.imshow(
flow_field_img_bins.T,
aspect="auto",
origin="lower",
extent=[0, 1, x_min, x_max],
cmap="coolwarm",
norm=color_norm,
)
ax2.set_ylim(x_min, x_max)
cbar = fig.colorbar(
im,
ax=[ax1, ax2, ax3],
orientation="horizontal",
fraction=0.08,
aspect=40,
pad=0.04,
)
cbar.set_label("Velocity field value (red pushes up, blue pulls down)")
# Create pathlines from the flow field
n_paths = 9
# Initialize tensor to store the flow process
paths = np.zeros((n_flow_steps + 1, n_paths))
x_t = torch.linspace(start=x_min_wide, end=x_max_wide, steps=n_paths, device=DEVICE).unsqueeze(-1)
paths[0] = x_t.cpu().squeeze().numpy()
# Generate the flow process
with torch.inference_mode():
    flow_matching_model.eval()
    ts = torch.linspace(0, 1, n_flow_steps + 1)
    for i in range(n_flow_steps):
        t = ts[i].expand_as(x_t).to(DEVICE)
        dt = ts[i + 1] - ts[i]
        x_t = x_t + flow_matching_model(t=t, x_t=x_t.to(DEVICE)) * dt
        paths[i + 1] = x_t.cpu().numpy().squeeze()
# Plot the pathlines with arrows using quiver
t_coords = np.linspace(0, 1, n_flow_steps + 1)
arrow_stride = 35  # space out arrows for clarity
arrow_offset = 7  # offset the arrows to the right for clarity
for i in range(n_paths):
    y = paths[:, i]
    idx = np.arange(arrow_offset, len(t_coords) - 1, arrow_stride)
    x0 = t_coords[idx]
    y0 = y[idx]
    u = np.diff(t_coords)[idx]
    v = np.diff(y)[idx]
    ax2.quiver(
        x0,
        y0,
        u,
        v,
        angles="xy",
        scale_units="xy",
        scale=0.6,
        units="inches",
        width=0.015,
        headwidth=6,
        headlength=9,
        headaxislength=7,
        pivot="tail",
        color="dimgray",
        alpha=0.9,
        zorder=2,
    )
    # overlay original line on top of arrows for clarity
    ax2.plot(
        t_coords,
        y,
        linestyle="-",
        color="dimgray",
        linewidth=1.5,
        alpha=0.7,
        zorder=3,
        label=f"Sample Path {i + 1}",
    )
ax2.set_xlabel("$t$: flow step")
ax2.set_title(r"Predicted velocity field ${FM}_{\theta}(x_t, t)$ with pathlines")
ax2.grid(False)
# Generate the predicted data by pushing the noise samples X0 through the learned flow
x_t = torch.from_numpy(data_x_0).float().to(DEVICE).unsqueeze(-1)
dt: float = 0.01  # Step size for Euler integration
n_flow_steps = int(1 / dt)
# Integrate the velocity field with Euler steps from t=0 to t=1
with torch.inference_mode():
    for i in range(n_flow_steps):
        t = torch.full_like(x_t, i * dt, device=DEVICE)
        x_t = x_t + flow_matching_model(t=t, x_t=x_t.to(DEVICE)) * dt
data_x_1 = x_t.cpu().numpy().squeeze()
ax3.set_title(r"Predicted Data $\hat{X}_1$")
ax3.hist(
data_x_1.flatten(),
bins=100,
alpha=0.5,
label=r"Predicted Data $\hat{X}_1$",
color="tab:blue",
density=True,
orientation="horizontal",
)
ax3.set_xlabel("density")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)
plt.show()
del (
fig,
ax1,
ax2,
ax3,
x_min,
x_max,
x_min_wide,
x_max_wide,
data_size, # fmt: skip
data_x_0,
data_x_1,
dt,
n_flow_steps,
img_hist_size,
flow_field_img_bins, # fmt: skip
flow_field_img_x_bin_edges,
flow_field_for_path_bins,
flow_field_for_path_x_bin_edges, # fmt: skip
y,
idx,
x0,
y0,
u,
i,
t,
x_t, # fmt: skip
)
#
Sampling from the trained model
At inference time we can sample a starting point $x_0$ from the noise distribution $π_0$ and then use the predicted velocity field ${FM}_{\theta}(x_t, t)$ to iteratively move (integrate) the sample towards a sample $\hat{x}_1$ from the target distribution $π_1$.
The code below starts with noise $x_0 \sim \mathcal{N}(0, 1)$ and integrates the learned ODE using the simple Euler method. At each step, the Euler method takes the velocity field prediction ${FM}_{\theta}(x_t, t)$ at the current position $x_t$ and time $t$ and moves the sample a small step $dt$ in the direction of the velocity field.
# Illustration on how to sample x_1 from x_0 using the learned velocity field
nb_steps = 15
path_x = np.zeros(nb_steps + 1) # Array to store the full sampled path
t_steps = np.linspace(0, 1, nb_steps + 1) # Steps $t$ in the range [0, 1]
# x_0 starting point (pre-selected here for the example, but ideally x_0 ~ N(0, I))
x_0 = torch.tensor([[0.85]], device=DEVICE)
with torch.inference_mode():
    flow_matching_model.eval()
    x_t = x_0  # Initialize the sample at the starting point
    path_x[0] = x_t.squeeze().cpu().numpy()
    # Integrate the velocity field using Euler integration from t=0 to t=1
    for i in range(nb_steps):
        t = t_steps[i]  # Current step $t$
        dt = t_steps[i + 1] - t_steps[i]  # Step size
        t_batch = torch.tensor([[t]], dtype=torch.float32, device=DEVICE)  # Expand the step to a batch dimension
        # Get the velocity prediction at the current position and time,
        # then move the sample a small step dt in that direction
        x_t = x_t + flow_matching_model(t=t_batch, x_t=x_t) * dt
        path_x[i + 1] = x_t.squeeze().cpu().numpy()
display(HTML(pd.DataFrame({"t": t_steps, "x": path_x}).transpose().to_html()))
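The Euler update used above is $x_{t+dt} = x_t + {FM}_{\theta}(x_t, t) \cdot dt$. To get a feel for how the number of steps affects the result, here is a minimal standalone sketch. It does not use the trained model: `toy_velocity` is a hypothetical stand-in velocity field $v(x, t) = -x$, chosen only because its exact solution $x_1 = x_0 e^{-1}$ is known, so the integration error can be measured directly.

```python
import math


def toy_velocity(x: float, t: float) -> float:
    """Hypothetical stand-in for the learned velocity field (NOT the trained model)."""
    return -x


def euler_integrate(x_0: float, nb_steps: int) -> float:
    """Integrate dx/dt = toy_velocity(x, t) from t=0 to t=1 with fixed-size Euler steps."""
    dt = 1.0 / nb_steps
    x_t = x_0
    for i in range(nb_steps):
        t = i * dt  # Current time, the last step starts at t = 1 - dt
        x_t = x_t + toy_velocity(x_t, t) * dt
    return x_t


x_0 = 0.85
exact_x_1 = x_0 * math.exp(-1.0)  # Closed-form solution of dx/dt = -x at t=1
errors = {n: abs(euler_integrate(x_0, n) - exact_x_1) for n in (5, 15, 50, 200)}
for n, err in errors.items():
    print(f"nb_steps={n:4d}  abs error={err:.5f}")
```

Euler is a first-order method, so the error shrinks roughly in proportion to the step size $dt$. For the learned velocity field there is no closed-form solution to compare against, but the same trade-off between the number of model evaluations and accuracy applies.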
We can illustrate this sampled path in the following animation which shows the integration from the noise sample $x_0$ towards the target distribution $\hat{x}_1$ using the predicted velocity field ${FM}_{\theta}(x_t, t)$ above. The velocity field is visualized as a heatmap where the vertical axis represents the position of the sample $x_t$ and the horizontal axis represents the flow step $t$ going from 0 on the left to 1 on the right. Red means a positive velocity (sample pushed up towards higher $x$) and blue means a negative velocity (sample pulled down towards lower $x$).
Notice that while we trained on straight-line paths, the sampled path is not necessarily a straight line. This is because we don't learn the individual paths directly. Instead we learn a single unconditional (marginal) velocity field by training on a large set of straight-line reference paths, and wherever reference paths cross the model can only predict their average velocity.
# Visualize animation of the flow matching integration (denoising) using Euler integration
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
1,
3,
figsize=(12, 5),
gridspec_kw={"width_ratios": [1, 5, 1]},
sharey=True,
dpi=100,
constrained_layout=True,
)
x_min, x_max = -2.2, 2.2 # Narrow view for flow field image
x_all_steps = np.linspace(-3, 3, 1000)
# Sample set of noise and target points
data_size: int = 100_000
np.random.seed(1) # Set random seeds for reproducibility
data_x_0 = np.random.randn(data_size)
# Plot the velocity field ##########################################
# Plot Noise samples X0
# Plot Noise distribution π₀
ax1.set_title("Noise Distribution π₀")
pdf_noise = scipy.stats.norm.pdf(x_all_steps, loc=0, scale=1)
ax1.plot(pdf_noise, x_all_steps, label="PDF Noise π₀", color="tab:orange")
ax1.fill_betweenx(x_all_steps, pdf_noise, alpha=0.4, color="tab:orange")
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("")
# ax11.set_xlim(0, 1)
# Plot final distribution x1
ax3.set_title("Target Distribution π₁")
pdf_target = mixture_pdf(x=x_all_steps)
ax3.plot(pdf_target, x_all_steps, label="PDF Target π₁", color="tab:blue")
ax3.fill_betweenx(x_all_steps, pdf_target, alpha=0.4, color="tab:blue")
ax3.set_xlabel("")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)
# ax13.set_xlim(0, 1)
# Plot Flow Field
n_flow_steps = 100
# Set up the flow field histogram parameters
img_hist_size = 200
# Narrow view for flow field image
flow_field_img_bins = np.zeros((n_flow_steps + 1, img_hist_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_size + 1)
flow_field_img_x_bin_centers = (
torch.from_numpy((flow_field_img_x_bin_edges[:-1] + flow_field_img_x_bin_edges[1:]) / 2)
.float()
.to(DEVICE)
.unsqueeze(-1)
)
# Evaluate the learned velocity field on a (t, x) grid by going over the discretized t-steps
# `flow_field_img_bins` holds the narrow view used for the flow field image
with torch.inference_mode():
    ts = torch.linspace(0, 1, n_flow_steps + 1, device=DEVICE)
    for i in range(n_flow_steps + 1):
        t = ts[i].expand_as(flow_field_img_x_bin_centers).to(DEVICE)
        # Get the model's prediction for the velocity field at the bin centers
        flow_field_img_bins[i] = flow_matching_model(t=t, x_t=flow_field_img_x_bin_centers).cpu().numpy().squeeze()
# Plot the flow field
max_abs_flow_field = np.nanmax(np.abs(flow_field_img_bins))
color_norm = matplotlib.colors.TwoSlopeNorm(vmin=-max_abs_flow_field, vcenter=0.0, vmax=max_abs_flow_field)
im = ax2.imshow(
flow_field_img_bins.T,
aspect="auto",
origin="lower",
extent=[0, 1, x_min, x_max],
cmap="coolwarm",
norm=color_norm,
)
ax2.set_ylim(x_min, x_max)
cbar = fig.colorbar(
im,
ax=[ax1, ax2, ax3],
orientation="horizontal",
fraction=0.08,
aspect=40,
pad=0.04,
)
cbar.set_label("Velocity field value (red pushes up, blue pulls down)")
# Draw the starting point of the Euler integration, the animation extends the path step by step
step_draw = ax2.scatter(
[0],
[x_0.squeeze().cpu().numpy()],
color="tab:blue",
zorder=4,
alpha=0.9,
)
(line_draw,) = ax2.plot(
[0],
[x_0.squeeze().cpu().numpy()],
linestyle=":",
color="tab:blue",
linewidth=1,
alpha=0.5,
zorder=3,
)
text_draw = ax2.text(
0.77,
-1.58,
f" t = 0.00\nstep = 0\n x = {path_x[0]:.2f}",
fontsize=16,
color="black",
fontfamily="monospace",
horizontalalignment="left",
verticalalignment="center",
bbox={"facecolor": "white", "alpha": 0.5, "pad": 8},
)
ax2.set_xlabel("$t$: flow step")
ax2.set_title(r"Euler integration of the predicted velocity field ${FM}_{\theta}(x_t, t)$")
ax2.grid(True)
ax2.legend(
[step_draw],
["Euler Integration Step"],
loc="lower left",
)
def update_animation(frame: int, step_draw, line_draw, text_draw, nb_steps):
    """
    Update the figure to show an animation of the integration of the velocity field.
    """
    xs = path_x[: frame + 1]
    ts = t_steps[: frame + 1]
    step_draw.set_offsets(np.stack([ts, xs]).T)
    line_draw.set_xdata(ts)
    line_draw.set_ydata(xs)
    text_draw.set_text(f" t = {ts[-1]:.2f}\nstep = {min(frame, nb_steps):d}\n x = {xs[-1]:.2f}")
    return (step_draw, line_draw)
# Create the animation
ani = FuncAnimation(
fig=fig,
func=functools.partial(
update_animation,
step_draw=step_draw,
line_draw=line_draw,
text_draw=text_draw,
nb_steps=nb_steps,
),
frames=nb_steps + 2,
interval=1,
)
ani.save(str(ANIMATION_FILE), writer="ffmpeg", fps=3)
plt.close(fig)
# Display the animation in the notebook
# Embed the video directly in the notebook by encoding the bytes as base64, this way it should hopefully also be exported
with ANIMATION_FILE.open("rb") as f:
    gif_data = f.read()
display(
    HTML(f"""
<video alt="Flow matching path integration" loop="true" autoplay="autoplay" muted>
<source type="video/mp4" src="data:video/mp4;base64,{base64.b64encode(gif_data).decode()}">
</video>
""")
)
del (fig, ax1, ax2, ax3, x_min, x_max, data_size, data_x_0, n_flow_steps, img_hist_size, # fmt: skip
flow_field_img_bins, flow_field_img_x_bin_edges, flow_field_img_x_bin_centers, max_abs_flow_field, # fmt: skip
color_norm, im, cbar, step_draw, line_draw, text_draw, update_animation, ani, gif_data) # fmt: skip
#
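To see why a sampled path can bend even though every reference path is straight, the sketch below integrates a velocity field that can be written in closed form. It makes assumptions that are not part of this post: a hypothetical bimodal target with two Gaussian components (`MEANS`, `STDS`, `WEIGHTS` are made-up values), noise $x_0 \sim \mathcal{N}(0, 1)$ sampled independently of $x_1$, and linear paths $x_t = (1 - t) x_0 + t x_1$. Under those assumptions the marginal velocity $E[x_1 - x_0 \mid x_t = x]$ follows from standard Gaussian conditioning, so no trained model is needed.

```python
import math

# Hypothetical bimodal target (assumed for illustration, not the post's exact mixture)
WEIGHTS = (0.5, 0.5)
MEANS = (-1.0, 1.0)
STDS = (0.3, 0.3)


def marginal_velocity(x: float, t: float) -> float:
    """Closed-form E[x_1 - x_0 | x_t = x] for x_t = (1 - t) * x_0 + t * x_1 with x_0 ~ N(0, 1)."""
    log_posteriors, cond_velocities = [], []
    for w, m, s in zip(WEIGHTS, MEANS, STDS):
        var_t = (1 - t) ** 2 + (t * s) ** 2  # Var[x_t | component]
        # Unnormalized log posterior of this component given x_t = x
        log_posteriors.append(math.log(w) - 0.5 * math.log(var_t) - 0.5 * (x - t * m) ** 2 / var_t)
        # Gaussian conditioning: E[x_1 - x_0 | x_t = x, component]
        cov_v_xt = t * s**2 - (1 - t)  # Cov[x_1 - x_0, x_t | component]
        cond_velocities.append(m + cov_v_xt / var_t * (x - t * m))
    # Normalize the posteriors in a numerically stable way
    max_log = max(log_posteriors)
    posteriors = [math.exp(lp - max_log) for lp in log_posteriors]
    return sum(p * v for p, v in zip(posteriors, cond_velocities)) / sum(posteriors)


def euler_path(x_0: float, nb_steps: int) -> list[float]:
    """Euler integration of the marginal velocity field from t=0 to t=1."""
    dt = 1.0 / nb_steps
    path = [x_0]
    for i in range(nb_steps):
        path.append(path[-1] + marginal_velocity(path[-1], i * dt) * dt)
    return path


path = euler_path(x_0=0.85, nb_steps=100)
print(f"start={path[0]:.2f}  lowest={min(path):.2f}  end={path[-1]:.2f}")
```

At $t = 0$ every reference path through $x$ is equally compatible with both modes, so the averaged velocity points towards the mean of the target and the sample first drifts towards 0. As $t$ grows, the nearer mode dominates the average and the path bends back towards it, which gives a curved trajectory.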
We can also take a large sample from the model $\hat{X}_1$ and reconstruct the target distribution $\pi_1$.
We'll define a `sample` function that will generate samples by integrating the learned vector field using Euler integration.
We'll then plot the target distribution and the reconstructed samples.
@torch.inference_mode()
def sample(
    n_samples: int,  # Number of samples to generate
    model: FlowMatchingModel,  # The flow matching model
    nb_steps: int,  # Number of Euler integration steps
) -> torch.Tensor:
    """Generates samples by integrating the learned vector field using Euler integration."""
    ts = torch.linspace(0, 1, nb_steps + 1, device=DEVICE)
    x_t = torch.randn(n_samples, data_dim).to(DEVICE)  # Sample x_0 ~ N(0, I)
    for i in range(nb_steps):  # Euler integration from t=0 to t=1 (last step happens just before t=1)
        t = ts[i]  # Current step $t$
        dt = ts[i + 1] - ts[i]  # Step size
        t_batch = t.expand(n_samples).unsqueeze(-1)
        # Move the sample a small step dt in the direction of the velocity field
        x_t = x_t + model(t=t_batch, x_t=x_t) * dt
    return x_t  # Final sample x_1
# Plot data distribution. This is the TARGET distribution (π₁)
fig, ax = plt.subplots(1, 1, figsize=(8, 3), constrained_layout=True, dpi=100)
# Plot the target distribution π₁
x_all_steps = np.linspace(-2.5, 2.5, 1000)
pdf_target = mixture_pdf(x=x_all_steps)
ax.plot(x_all_steps, pdf_target, label=r"PDF Target $\pi_1$", color="tab:purple")
ax.fill_between(x_all_steps, pdf_target, alpha=0.4, color="tab:purple")
# Plot the samples
x1_samples = sample(n_samples=100_000, model=flow_matching_model, nb_steps=50).cpu().numpy().flatten()
sns.histplot(
x=x1_samples,
bins=100,
color="tab:blue",
kde=False,
alpha=0.8,
ax=ax,
stat="density",
label=r"Reconstructed samples $\hat{X}_1$",
)
ax.legend()
ax.set_title(r"Target Distribution ($\pi_1$) vs Reconstructed Samples ($\hat{X}_1$)")
ax.set_xlabel("x")
ax.set_ylabel("density")
plt.show()
del fig, ax, x_all_steps, pdf_target, x1_samples
#
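The comparison above is visual. As a rough sketch of how the match could also be put into a single number, the block below computes an approximate L1 distance between a sample histogram and a density. Everything in it is an assumption for illustration: `target_pdf` is a hypothetical bimodal mixture (not necessarily the `mixture_pdf` used in this post), and the two sample sets are stand-ins drawn directly with `random`, not samples from the trained model.

```python
import math
import random

# Hypothetical bimodal target (assumed parameters, not necessarily the post's mixture)
WEIGHTS, MEANS, STDS = (0.5, 0.5), (-1.0, 1.0), (0.3, 0.3)


def target_pdf(x: float) -> float:
    """Density of the assumed Gaussian mixture."""
    return sum(
        w * math.exp(-0.5 * ((x - m) / s) ** 2) / (s * math.sqrt(2 * math.pi))
        for w, m, s in zip(WEIGHTS, MEANS, STDS)
    )


def histogram_l1_distance(samples, pdf, x_min=-2.5, x_max=2.5, bins=50) -> float:
    """Approximate the integral of |histogram density - pdf| over [x_min, x_max]."""
    width = (x_max - x_min) / bins
    counts = [0] * bins
    for x in samples:
        if x_min <= x < x_max:
            counts[min(int((x - x_min) / width), bins - 1)] += 1
    distance = 0.0
    for i, count in enumerate(counts):
        density = count / (len(samples) * width)  # Same normalization as a density histogram
        center = x_min + (i + 0.5) * width
        distance += abs(density - pdf(center)) * width
    return distance


rng = random.Random(0)
n = 20_000
# Stand-in for well-reconstructed samples: drawn directly from the assumed target
components = rng.choices(range(len(WEIGHTS)), weights=WEIGHTS, k=n)
good_samples = [rng.gauss(MEANS[k], STDS[k]) for k in components]
# Stand-in for a poor reconstruction: plain N(0, 1) noise
noise_samples = [rng.gauss(0.0, 1.0) for _ in range(n)]

distance_good = histogram_l1_distance(good_samples, target_pdf)
distance_noise = histogram_l1_distance(noise_samples, target_pdf)
print(f"L1(good, target)={distance_good:.3f}  L1(noise, target)={distance_noise:.3f}")
```

A value close to 0 means the histogram follows the density closely, while the upper bound of 2 corresponds to distributions that do not overlap at all. The value depends on the bin count and sample size, so it is best used to compare runs under the same settings.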
As a final illustration, let's visualize the path density between the starting noise samples $X_0$ and the final reconstructed samples $\hat{X}_1$ by sampling a large number of paths from the noise distribution $\pi_0$ to the target distribution $\pi_1$.
# Path density
@torch.inference_mode()
def sample_paths(
    n_samples: int,
    model: FlowMatchingModel,
    nb_steps: int,  # Number of Euler integration steps
) -> torch.Tensor:
    """Generates samples by integrating the learned vector field, keeping track of the intermediate steps."""
    x_all_steps = torch.zeros(n_samples, nb_steps + 1).to(DEVICE)
    ts = torch.linspace(0, 1, nb_steps + 1, device=DEVICE)
    x_t = torch.randn(n_samples, 1).to(DEVICE)  # Sample x_0 ~ N(0, I)
    x_all_steps[:, 0] = x_t.squeeze()
    for i in range(nb_steps):  # Euler integration from t=0 to t=1
        t = ts[i]  # Current step $t$
        dt = ts[i + 1] - ts[i]  # Step size
        t_batch = t.expand(n_samples).unsqueeze(-1)  # Expand the step to a batch dimension
        # Move the sample a small step dt in the direction of the velocity field
        x_t = x_t + model(t=t_batch, x_t=x_t) * dt
        x_all_steps[:, i + 1] = x_t.squeeze()
    return x_all_steps
nb_paths = 100_000
nb_steps = 200
paths = sample_paths(n_samples=nb_paths, model=flow_matching_model, nb_steps=nb_steps).cpu().numpy()
# Illustration of the paths sampled from the trained model
# Set up the plot
fig, (ax1, ax2, ax3) = plt.subplots(
1,
3,
figsize=(12, 4),
gridspec_kw={"width_ratios": [1, 5, 1]},
sharey=True,
dpi=100,
)
fig.subplots_adjust(wspace=0)
x_min, x_max = -2.5, 2.5
# Sample set of noise and target points
np.random.seed(1) # Set random seeds for reproducibility
torch.manual_seed(1)
# Plot the full data distribution ##################################
# Plot Noise samples X0
ax1.set_title("Noise Samples X₀")
ax1.hist(
paths[:, 0],
bins=100,
alpha=0.5,
label="Noise π₀",
color="tab:orange",
density=True,
orientation="horizontal",
)
ax1.invert_xaxis()
ax1.set_ylabel("x")
ax1.set_xlabel("density")
# Plot the reconstructed samples (end points of the sampled paths)
ax3.set_title(r"Reconstructed samples $\hat{X}_1$")
ax3.hist(
paths[:, -1],
bins=100,
alpha=0.5,
label=r"Reconstructed samples $\hat{X}_1$",
color="tab:blue",
density=True,
orientation="horizontal",
)
ax3.set_xlabel("density")
ax3.yaxis.set_label_position("right")
ax3.set_ylabel("x")
# Also show y-axis values on the right side
ax3.yaxis.set_visible(True)
ax3.yaxis.set_tick_params(labelright=True)
# Plot path density
# Set up the path density histogram parameters
img_hist_x_size = 480
path_density_bins = np.zeros((nb_steps + 1, img_hist_x_size))
flow_field_img_x_bin_edges = np.linspace(x_min, x_max, img_hist_x_size + 1)
# Build up the histogram of the sampled paths by going over the discretized t-bins
for i in range(nb_steps + 1):
    path_density_bins[i] = np.histogram(paths[:, i], bins=flow_field_img_x_bin_edges)[0]
im = ax2.imshow(path_density_bins.T, aspect="auto", origin="lower", extent=[0, 1, x_min, x_max], cmap="viridis")
ax2.set_xlabel("$t$: flow step")
ax2.set_title(r"Path density of paths sampled using Euler integration of ${FM}_{\theta}(x_t, t)$")
ax2.grid(False)
plt.tight_layout()
plt.show()
del (fig, ax1, ax2, ax3, x_min, x_max, nb_paths, nb_steps, paths, img_hist_x_size, # fmt: skip
path_density_bins, flow_field_img_x_bin_edges, i, im) # fmt: skip
#
Summary
To conclude, we've implemented a simple flow matching model and trained it on 1D toy data. The 1D toy data allowed us to easily visualize the flow matching model, the velocity field, and the sampled paths.
In real-world applications the target distribution is unknown (we only have a finite set of samples from it) and far more complex. The resulting vector field is more complex as well, which typically calls for larger models to learn it and more sophisticated sampling strategies to sample from the model.
References and further reading
- Flow matching paper: Flow Matching for Generative Modeling
- Rectified flow paper: Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow
- Flow Matching Guide and Code
- An introduction to flow matching
# Python package versions used
%load_ext watermark
%watermark --python
%watermark --iversions
#
This post at peterroelants.github.io is generated from an IPython notebook file. Link to the full IPython notebook file