Flow Matching Generation 原理与Demo

最新更新于: 2026年3月17日晚上9点14分

参考论文1. Flow Matching for Generative Modeling - 2022

Flow Matching原理

Flow Matching是一个非常有名的生成模型框架,其他的生成模型还有:VAE、GAN、Diffusion (DDPM)等,Flow Matching优势在于训练效率与生成效率都非常高。本文参考论文1从零开始分析原理,最后给出两个Demo测试该算法。

生成模型要解决什么问题?

对于某个数据集 DRd\mathcal{D}\subset\mathbb{R}^d,其中 dd 表示样本 xDx\in\mathcal{D} 的维度,例如

  • 1080P的三通道图像维度就是 d=1920×1080×3d=1920\times 1080\times 3
  • 机器人控制就是全部可控制电机的个数 d=ndofd=n_{dof}
  • U[0,1]\mathcal{U}[0,1] 均匀分布中采样得到的点集 d=1d=1
  • dd 维正态分布 N(μ,Σ)\mathcal{N}(\mu,\Sigma) 中采样得到的点集

我们希望得到找到 D\mathcal{D} 中数据所满足的分布,称之为真实分布 P1P_1,从而能直接从该分布中采样进而生成出新样本 xP1x\sim P_1,达到数据生成的目的

上述的各种生成模型有不同的方法近似得到该策略,Flow Matching给出了一种从噪声分布(标准正态分布)到目标分布的算法,其背后的数学/物理原理非常巧妙很有意思

前置定义

设数据集为 DRd\mathcal{D}\subset\mathbb{R}^d,数据记为 x1Dx_1\in\mathcal{D},其中 dd 为数据维度,x1x_1 服从 P1P_1 真实分布,我们希望通过训练的方法近似求出该真实分布

为了解决这个问题,我们直接从直观的思路上来想,是否能每次对噪声分布做微小的移动,从而逐渐移动到真实分布上,下文的“粒子”就是机器学习中的数据,用粒子只是更贴近物理术语好理解些,具体而言:

定义1(概率密度路径, probability density path):与时间 t[0,1]t\in[0,1] 相关的概率密度函数 pt(x):=p(t,x):[0,1]×RdRp_t(x):=p(t,x): [0,1]\times \mathbb{R}^d\to\mathbb{R} 称为概率密度路径

定义2(流, flow)ϕt(x):=ϕ(t,x):[0,1]×RdRd\phi_t(x):=\phi(t,x):[0,1]\times\mathbb{R}^d\to\mathbb{R}^d 为粒子 xx 在时间 tt 下移动到的位置

定义3(时变向量场, time-dependent vector field, VF)ut(x):=u(t,x):[0,1]×RdRdu_t(x):=u(t,x):[0,1]\times\mathbb{R}^d\to\mathbb{R}^dtt 时刻下粒子 xx 的移动方向

不难发现,流就是一个粒子空间位置,时变向量场就是粒子速度,二者满足如下关系(微分和积分关系)

ddtϕt(x)=ut(ϕt(x))\frac{\mathrm{d}}{\mathrm{d} t}\phi_t(x) = u_t(\phi_t(x))

数学上称之为微分同胚(diffeomorphic)该映射具有连续可微的性质,换句话说就是一一对应

我们还可以发现流 ϕt\phi_t 描述了每个粒子在每个时刻下的位置,因此对于 tt 时刻下粒子的出现概率 ptp_t,给出 p0p_0ϕt\phi_t,则不难得到 pt=[ϕt]p0p_t=[\phi_t]_{*}p_0,数学上称该算子 [][\cdot]_{*}前推算子,具体表达式见论文中Eq.4

我们可以固定一个初始随机分布 p0p_0(标准正态分布) 通过某个流 ϕt\phi_t 到达最后的 p1p_1,因此问题转化为找到 ϕt\phi_t,而流和时变向量场又是一一对应关系,求解 utu_t 变为我们的终极问题

定义4(Flow Matching, FM):对于神经网络参数化的函数 vt(x;θ):[0,1]×RdRdv_t(x;\theta):[0,1]\times\mathbb{R}^d\to\mathbb{R}^d 我们期望最小化目标

LFM(θ):=Et,pt(x)vt(x;θ)ut(x)2(1)\mathcal{L}_{\text{FM}}(\theta):=\mathbb{E}_{t,p_t(x)}||v_t(x;\theta)-u_t(x)||^2 \tag{1}

其中 tU[0,1]t\sim\mathcal{U}[0,1] 均匀分布,xpt(x)x\sim p_t(x),该目标称为Flow Matching目标。

P.S. 这个目标其实是向量场的近似,但却被称为流匹配,可能是求解流是最终目标,而向量场则是求解其的等价替代品

本论文提出关键核心就是将FM转为可求解的条件流匹配(Condition FM),下面详细介绍

Condition Flow Matching两个核心定理

直接求解 utu_t 没有头绪,但是将其通过条件分布边缘化(积分)是否可以得到呢,于是引出如下定理

定理1(边缘化条件概率路径)

x1Dx_1\in\mathcal{D},条件概率路径 pt(xx1)p_t(x|x_1) 由条件向量场 ut(xx1)u_t(x|x_1) 得到,对于真实分布 q(x1)q(x_1),向量场可通过边缘化得到

ut(x)=ut(xx1)pt(xx1)q(x1)pt(x)dx1u_t(x) = \int u_t(x|x_1)\frac{p_t(x|x_1)q(x_1)}{p_t(x)}\mathrm{d}x_1

证明:我们利用到两个重要公式

全概率pt(x)=pt(xx1)q(x1)dx1连续性方程pt(x)t=div(ut(x)pt(x))\begin{aligned} \text{全概率}\quad p_t(x)=\int p_t(x|x_1)q(x_1)\mathrm{d}x_1 \\ \text{连续性方程}\quad \frac{\partial p_t(x)}{\partial t} = -\text{div}(u_t(x)p_t(x)) \end{aligned}

连续性方程为流体力学中的重要公式,描述了场(流速)utu_t 和概率路径(流体密度)ptp_t 的关系,utptu_tp_t 为概率通量(质量流量,流过单位截面的概率质量),该公式表示局部密度变化等于流入和流出该区域的通量差

div(ut(x)pt(x))=连续性方程tpt(x)=全概率 tpt(xx1)q(x1)dx1=连续性方程 div(ut(xx1)pt(xx1))q(x1)dx1= divut(xx1)pt(xx1)q(x1)dx1\begin{aligned} -\text{div}(u_t(x)p_t(x))\xlongequal{\text{连续性方程}}\frac{\partial}{\partial t}p_t(x)\xlongequal{\text{全概率}}&\ \int\frac{\partial}{\partial t}p_t(x|x_1)q(x_1)\mathrm{d}x_1\\ \xlongequal{\text{连续性方程}}&\ \int-\text{div}(u_t(x|x_1)p_t(x|x_1))q(x_1)\mathrm{d}x_1\\ =&\ -\text{div}\int u_t(x|x_1)p_t(x|x_1)q(x_1)\mathrm{d}x_1\\ \end{aligned}

因此在常见正则性假设下,可取

ut(x)pt(x)=ut(xx1)pt(xx1)q(x1)dx1u_t(x)p_t(x)=\int u_t(x|x_1)p_t(x|x_1)q(x_1)\mathrm{d}x_1

两边同除 pt(x)p_t(x) 即得结论

QED

定理2(FM与CFM关于网络参数具有相同梯度)

q(x1)q(x_1) 为真实分布,x1qx_1\sim q 为真实数据,则条件流匹配(Condition Flow Matching, CFM)最小化目标为

LCFM(θ):=Et,q(x1),pt(xx1)vt(x;θ)ut(xx1)2(2)\mathcal{L}_{\text{CFM}}(\theta):=\mathbb{E}_{t,q(x_1),p_t(x|x_1)}||v_t(x;\theta)-u_t(x|x_1)||^2 \tag{2}

式(1)与(2)关系为 LFM(θ)=LCFM(θ)\nabla\mathcal{L}_{\text{FM}}(\theta)=\nabla\mathcal{L}_{\text{CFM}}(\theta),即关于 θ\theta 的导数相同,梯度下降法求解 vt(x;θ)v_t(x;\theta) 二者等价

证明
由于 ab2=a22a,b+b2||a-b||^2=||a||^2-2\langle a,b\rangle+||b||^2,则

vt(x)ut(x)2= vt22vt,ut+ut2vt(x)ut(xx1)2= vt22vt,ut(xx1)+ut(xx1)2\begin{aligned} ||v_t(x)-u_t(x)||^2=&\ ||v_t||^2-2\langle v_t,u_t\rangle + ||u_t||^2\\ ||v_t(x)-u_t(x|x_1)||^2=&\ ||v_t||^2-2\langle v_t,u_t(x|x_1)\rangle + ||u_t(x|x_1)||^2 \end{aligned}

由于最后关于 θ\theta 求导,仅考虑包含 θ\theta 的项,即前两项,只需分别证明期望意义下相等:

第一项

Et,pt(x)vt(x)2= vt(x)2pt(x)dx=vt(x)2(pt(xx1)q(x1)dx1)dx=Fubini定理 vt(x)2pt(xx1)q(x1)dx1dx= Et,q(x1),pt(xx1)vt(x)2\begin{aligned} \mathbb{E}_{t,p_t(x)}||v_t(x)||^2=&\ \int ||v_t(x)||^2p_t(x)\mathrm{d}x=\int ||v_t(x)||^2\left(\int p_t(x|x_1)q(x_1)\mathrm{d}x_1\right)\mathrm{d}x\\ \xlongequal{\text{Fubini定理}}&\ \iint||v_t(x)||^2p_t(x|x_1)q(x_1)\mathrm{d}x_1\mathrm{d}x\\ =&\ \mathbb{E}_{t,q(x_1),p_t(x|x_1)}||v_t(x)||^2 \end{aligned}

第二项

Et,pt(x)vt(x),ut(x)= vt(x),ut(x)pt(x)dx=定理1 vt(x),ut(xx1)pt(xx1)q(x1)pt(x)dx1pt(x)dx= vt(x),ut(xx1)pt(xx1)q(x1)dx1dx=Fubini定理 vt(x),ut(xx1)pt(xx1)q(x1)dx1dx= Et,q(x1),pt(xx1)vt(x),ut(xx1)\begin{aligned} \mathbb{E}_{t,p_t(x)}\langle v_t(x),u_t(x)\rangle=&\ \int\langle v_t(x),u_t(x)\rangle p_t(x)\mathrm{d}x\\ \xlongequal{\text{定理1}}&\ \int\left\langle v_t(x),\int u_t(x|x_1)\frac{p_t(x|x_1)q(x_1)}{p_t(x)}\mathrm{d}x_1\right\rangle p_t(x)\mathrm{d}x\\ =&\ \int\left\langle v_t(x),\int u_t(x|x_1)p_t(x|x_1)q(x_1)\mathrm{d}x_1\right\rangle \mathrm{d}x\\ \xlongequal{\text{Fubini定理}}&\ \iint \langle v_t(x),u_t(x|x_1)\rangle p_t(x|x_1)q(x_1)\mathrm{d}x_1\mathrm{d}x\\ =&\ \mathbb{E}_{t,q(x_1),p_t(x|x_1)}\langle v_t(x),u_t(x|x_1) \rangle \end{aligned}

综上:LFM(θ)=LCFM(θ)+C\mathcal{L}_{\text{FM}}(\theta)=\mathcal{L}_{\text{CFM}}(\theta)+C,则 θLFM=θLCFM\nabla_{\theta}\mathcal{L}_{\text{FM}}=\nabla_{\theta}\mathcal{L}_{\text{CFM}}

QED

如何训练?

我们证明了最重要的定理条件流匹配定理2,如何使用它来训练呢?随机采样数据集中的一个样本 x1x_1,虽然目标向量场 utu_t 很难获得,但是条件向量场 ut(xx1)u_t(x|x_1) 确实可以直接构造得到的,这个构造只需找到两个边界条件做线性插值即可,论文中将其称之为最优传输插值(Optimal Transport (OT) interpolant),因为我们可以观察两个特殊时间点的分布:

  • t=0t=0 时,p0(z)N(0,I)p_0(z)\sim\mathcal{N}(0,I)(标准正态分布)
  • t=1t=1 时,p1(xx1)=N(x1,εI)p_1(x|x_1)=\mathcal{N}(x_1,\varepsilon I)(一个均值为 x1x_1、协方差为 εI\varepsilon Iε0\varepsilon\approx 0 的正态分布)

于是容易通过线性插值构造出OT路径,也称条件流(Conditional flow)

ψt(z)=(1(1ε)t)z+tx1,其中 zN(0,I)(3)\psi_{t}(z) = (1-(1-\varepsilon)t)z+tx_1,\quad \text{其中}\ z\sim\mathcal{N}(0,I) \tag{3}

由于流和场就是微分关系,因此条件向量场就是

ut(xx1)=ddtψt(z)=x1(1ε)z(4)u_t(x|x_1) = \frac{\mathrm{d}}{\mathrm{d}t}\psi_t(z) = x_1 - (1-\varepsilon)z \tag{4}

其中 x=ψt(z)x = \psi_{t}(z),由式(3)可知 x=(1(1ε)t)z+tx1z=xtx11(1ε)tx = (1-(1-\varepsilon)t)z+tx_1\Rightarrow z = \dfrac{x-tx_1}{1-(1-\varepsilon)t},带入式(4)可得

ut(xx1)=x1(1ε)xtx11(1ε)t=x1(1ε)x1(1ε)tu_t(x|x_1) = x_1-(1-\varepsilon)\frac{x-tx_1}{1-(1-\varepsilon)t} = \frac{x_1-(1-\varepsilon)x}{1-(1-\varepsilon)t}

x=ψt(z)x=\psi_t(z) 和式(4)带入CFM损失函数 式(2)中,可以得到我们训练神经网络的最小化目标

LCFM= Et,q(x1),pt(xx1)vt(x;θ)ut(xx1)2= Et,q(x1),pt(xx1)vt(ϕt(z);θ)(x1(1ε)z)2\begin{aligned} \mathcal{L}_{\text{CFM}} =&\ \mathbb{E}_{t,q(x_1),p_t(x|x_1)}||v_t(x;\theta)-u_t(x|x_1)||^2\\ =&\ \mathbb{E}_{t,q(x_1),p_t(x|x_1)}||v_t(\phi_t(z);\theta)-(x_1-(1-\varepsilon)z)||^2 \end{aligned}

具体训练中,训练集随机选择batch集合,其中样本为 x1x_1,从均匀分布中采样时间 tU[0,1]t\sim\mathcal{U}[0,1],带入 LCFM\mathcal{L}_{\text{CFM}} 中,计算梯度即可

vt(x;θ)v_t(x;\theta) 的网络选择可以是:MLP(简单的拟合),UNet+ResNet(图像生成),Transformer(复杂的拟合,机器人控制,如OmniXtreme)

如何推理?

我们的目标式通过 p0p1p_0 \to p_1,中间的过程是通过每个时刻 tt 下的 ut(x)u_t(x) 给出,所以最简单的方法就是固定时刻步进长度 Δt\Delta t,然后每次步进这个长度的距离即可,这个方法也称为Euler法

xt+Δt=xt+vt(xt)Δt,x0N(0,I)x_{t+\Delta t}=x_t+v_t(x_t)\cdot \Delta t,\quad x_0\sim\mathcal{N}(0,I)

这就是ODE求解器,也是求解数值积分面积的方法,但是Euler法还是太过于暴力且不精准,论文中使用的是Dormand-Prince method(dopri5,多尔曼-普林斯5阶近似方法),也就是对Runge-Kutta method(RK4, 龙格-库塔4阶近似方法)的改进,能自动调整 Δt\Delta t 的大小,在梯度较小式增大 Δt\Delta t,从而加快生成速度,且精确比Euler法更高

例子

下面这些例子中我们就用简单的欧拉法,均匀设定步长,当推进次数为 nn 时,步长为 Δt=1/n\Delta t=1/n,来体现不同步长时候生成的结果

依赖安装包pytorch,matplotlib,任意版本python,最好别低于3.8

代码均为Gemini 3.1 Pro生成,经过调试得到,训练显卡为RTX 4080

二维棋盘图分布

d=2d=2,我们的目标分布是 [2,2]2[-2,2]^2 上长度为 11 均匀分布的正方形棋盘,下左图所示

目标分布 生成分布
target distribution generation

网络使用的是5层512神经元的MLP,训练50秒得到,完整代码如下

MNIST手写数字生成

训练这个最好先确定PyTorch有显卡加速不然太慢,训练用时21分钟,网络使用UNet+ResNet,生成效果如下

step0 step1 step2 step3 step4
step0 step1 step2 step3 step4
step10 step100 step300 step1000
step10 step100 step300 step1000

Flow Matching Generation 原理与Demo
https://wty-yy.xyz/posts/303/
作者
wty
发布于
2026年3月16日
许可协议