連続時間拡散モデルとフローマッチング理論

このノートでは、連続時間拡散と Flow Matching を同じ座標系で理解します。狙いは「式を覚える」ことではなく、

を一貫して追えるようにすることです。

Flow Matching は、途中の「正しい動き方」を直接学ぶ

拡散モデルではノイズ除去として逆過程を組み立てますが、Flow Matching ではノイズ点からデータ点へ至る経路上の速度場そのものを学びます。何を予測するかが違うだけで、連続時間生成としてはかなり近い世界にいます。

このノートでは、まずノイズとデータを結ぶ path を定義し、次にその path に沿う目標速度 utu_t を回帰します。そのあとで time conditioning の意味、path の選び方、OT-CFM が何を工夫しているかまでを見て、連続時間生成の見取り図を作ります。

読み筋は「どの道を通るか」と「その道でどう動くか」です

Flow Matching の設計自由度は、path と vector field の 2 箇所にあります。ここが見えると、線形ブリッジと別のブリッジの違いも、単なる式差分ではなく学習対象の違いとして読めます。

最初に path を並べるのは、生成器より経路設計が主役だからです

ノイズからデータへ行くなら何でもよいわけではありません。どの中間点を通るかで、学習しやすさも速度場の形も変わります。

速度場を学ぶとは、その時刻でどちらへ動くかを当てること

スコアベースモデルが確率上昇方向を学ぶのに対し、Flow Matching は経路に沿った目標速度を直接回帰します。ここがこの notebook のいちばん大事な違いです。

時刻入力が必要なのは、同じ場所でも「いま何時か」で動き方が変わるから

連続時間生成では、初期と終盤で望ましい速度が違います。t を消すと、その違いを表せなくなります。

OT-CFM は loss を変えるより、ペアの作り方を変える

ここは初学者が誤解しやすい点です。何を回帰するかは同じでも、どの (x_0, x_1) を結びつけるかを工夫すると、学習しやすい flow が作れるようになります。

まずはノイズ点とデータ点を橋で結ぶ

最初の節では、ノイズ源とターゲット分布を置き、線形 path と別の path がどう違うかを可視化します。

import math
import random
import statistics
import time

Flow Matching の基本は、ノイズ点 x0x_0 とデータ点 x1x_1 を結ぶ経路 xtx_t を設計し、対応する目標速度 utu_t を回帰することです。

最初に記号をそろえます。

学習目的は

minθ  Et,x0,x1[vθ(xt,t)ut(x0,x1,t)2]\min_\theta\; \mathbb{E}_{t,x_0,x_1}\left[\|v_\theta(x_t,t)-u_t(x_0,x_1,t)\|^2\right]

です。二乗誤差回帰の一般事実として、各点 (x,t)(x,t) における最適予測は条件付き期待値になります。したがって最適解は

v(x,t)=E[utxt=x]v^*(x,t)=\mathbb{E}[u_t\mid x_t=x]

であり、「各ペア速度をそのまま暗記」ではなく「同じ (x,t)(x,t) に集まる速度の平均場」を学習している、と読むのが正確です。

学習後は ODE

dxdt=vθ(x,t),x(0)p0\frac{dx}{dt}=v_\theta(x,t),\quad x(0)\sim p_0

を積分して生成します。

連続時間拡散モデルとの接続(上級・初回は読み飛ばしてもOK):

前向き過程を SDE

dx=ft(x)dt+gtdWtdx = f_t(x)dt + g_t dW_t

で定めたとき、ここで

です。代表的な前提(例えば gtg_t が状態非依存の設定)では確率流 ODE を

dxdt=ft(x)12gt2xlogpt(x)\frac{dx}{dt}=f_t(x)-\frac{1}{2}g_t^2\nabla_x \log p_t(x)

と書けます。拡散側は score を経由して速度場を得るのに対し、Flow Matching は経路と教師速度を直接設計して速度場を学習します。

注意: 拡散文献では時間向きを data→noise で書くことが多いのに対し、このノートは Flow Matching の実装導線に合わせて noise→data 向きで記述しています。

注意: このノートでは x_0=ノイズ, x_1=データ と置きます。拡散文献の一部では x0x_0 をデータと書くため、添字の意味を都度確認してください。

まず2次元データを用意します。ノイズ源は標準ガウス、ターゲットは2峰の2次元混合ガウスです。

random.seed(41)

def sample_noise_2d(n: int):
    out = []
    for _ in range(n):
        out.append((random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)))
    return out


def sample_data_2d(n: int):
    out = []
    for _ in range(n):
        if random.random() < 0.5:
            x = random.gauss(-2.2, 0.45)
            y = random.gauss(1.6, 0.55)
        else:
            x = random.gauss(2.1, 0.5)
            y = random.gauss(-1.5, 0.5)
        out.append((x, y))
    return out


def describe_2d(xs):
    mx = statistics.mean(x for x, _ in xs)
    my = statistics.mean(y for _, y in xs)
    sx = statistics.pstdev(x for x, _ in xs)
    sy = statistics.pstdev(y for _, y in xs)
    q1 = sum(1 for x, y in xs if x < 0 and y >= 0) / len(xs)
    q3 = sum(1 for x, y in xs if x >= 0 and y < 0) / len(xs)
    return {'mx': mx, 'my': my, 'sx': sx, 'sy': sy, 'q1': q1, 'q3': q3}


noise_data = sample_noise_2d(3200)
target_data = sample_data_2d(3200)

print('noise stats =', {k: round(v, 4) for k, v in describe_2d(noise_data).items()})
print('target stats=', {k: round(v, 4) for k, v in describe_2d(target_data).items()})

次に2種類のパスを定義します。

  1. 線形パス(独立カップリング版の CFM)
  1. 三角関数パス(連続時間拡散のノイズ減衰を意識した滑らかな経路)

DDPM 的な離散式 x_t = sqrt(alpha_bar_t) x_1 + sqrt(1-alpha_bar_t) x_0 を連続時間で見ると、「データ係数を増やしノイズ係数を減らす」経路設計になります。このノートの三角関数パスは、その直観を連続時間で観察しやすくした簡易版です。

係数の違いも押さえておきます。線形は alpha+sigma=1、三角関数は alpha2+sigma2=1alpha^2+sigma^2=1 です。例えば t=0.5 では線形が (0.5, 0.5)、三角関数が (0.707..., 0.707...) なので、中間の混ざり方が異なります。

注意: ここでの線形パスは x_0, x_1 を独立に引いた組を使っており、厳密な OT カップリングではありません。

参考資料で扱う OT-CFM(Optimal-Transport Conditional Flow Matching)では、損失関数の形自体は同じでも、学習に使う (x0,x1)(x_0, x_1) の組を最適輸送カップリングで作る点が重要です。

直観的には、ノイズ側の点とデータ側の点を「できるだけ無理のない対応」で結び、学習すべきベクトル場を素直にする工夫です。ミニバッチ単位で近似 OT を解く実装が一般的で、画像規模では計算コストと近似精度のトレードオフ設計が実務上のポイントになります。

def linear_bridge(x0, x1, t):
    x = (1.0 - t) * x0[0] + t * x1[0]
    y = (1.0 - t) * x0[1] + t * x1[1]
    vx = x1[0] - x0[0]
    vy = x1[1] - x0[1]
    return (x, y), (vx, vy)


def trig_bridge(x0, x1, t):
    # alpha(0)=0, sigma(0)=1; alpha(1)=1, sigma(1)=0
    alpha = math.sin(0.5 * math.pi * t)
    sigma = math.cos(0.5 * math.pi * t)
    alpha_dot = 0.5 * math.pi * math.cos(0.5 * math.pi * t)
    sigma_dot = -0.5 * math.pi * math.sin(0.5 * math.pi * t)

    x = alpha * x1[0] + sigma * x0[0]
    y = alpha * x1[1] + sigma * x0[1]
    vx = alpha_dot * x1[0] + sigma_dot * x0[0]
    vy = alpha_dot * x1[1] + sigma_dot * x0[1]
    return (x, y), (vx, vy)


def path_length_demo(n=600):
    # 同じカップル(x0,x1)に対し、2パスの移動距離を比較
    total_linear = 0.0
    total_trig = 0.0
    for _ in range(n):
        x0 = noise_data[random.randrange(len(noise_data))]
        x1 = target_data[random.randrange(len(target_data))]

        # linear length (closed form)
        dl = math.sqrt((x1[0] - x0[0]) ** 2 + (x1[1] - x0[1]) ** 2)

        # trig length (numerical)
        prev, _ = trig_bridge(x0, x1, 0.0)
        dt = 1.0 / 120
        s = 0.0
        for i in range(1, 121):
            t = i * dt
            cur, _ = trig_bridge(x0, x1, t)
            s += math.sqrt((cur[0] - prev[0]) ** 2 + (cur[1] - prev[1]) ** 2)
            prev = cur

        total_linear += dl
        total_trig += s

    return total_linear / n, total_trig / n


lin_len, trig_len = path_length_demo()
print('avg path length linear =', round(lin_len, 4))
print('avg path length trig   =', round(trig_len, 4))
print('ratio trig/linear      =', round(trig_len / lin_len, 4))

速度場モデル vtheta(x,t)v_theta(x,t) は、2次元入力 + 時刻に対する低容量回帰器にします。U-Netではなく軽量モデルを使うのは、目的が理論理解だからです。

この設定でも、Flow Matching の学習が「ODEのロールアウトなし」にできることは確認できます。

# feature map for vector field regression
# phi(x,t) -> 8 features
def vf_features(x, y, t):
    return [1.0, x, y, t, x * t, y * t, x * x, y * y]
def vf_predict(theta, x, y, t):
    # theta has 16 params: first 8 for vx, last 8 for vy
    f = vf_features(x, y, t)
    vx = sum(theta[i] * f[i] for i in range(8))
    vy = sum(theta[8 + i] * f[i] for i in range(8))
    return vx, vy
def train_flow_matching(objective='linear', steps=2600, batch_size=128, lr=0.012, seed=2026):
    rng = random.Random(seed)
    theta = [0.0] * 16
    history = []
    bridge = linear_bridge if objective == 'linear' else trig_bridge
    for step in range(steps + 1):
        grads = [0.0] * 16
        loss = 0.0
        for _ in range(batch_size):
            x0 = noise_data[rng.randrange(len(noise_data))]
            x1 = target_data[rng.randrange(len(target_data))]
            t = rng.random()
            (xtx, xty), (utx, uty) = bridge(x0, x1, t)
            pvx, pvy = vf_predict(theta, xtx, xty, t)
            ex = pvx - utx
            ey = pvy - uty
            f = vf_features(xtx, xty, t)
            for i in range(8):
                grads[i] += 2.0 * ex * f[i]
                grads[8 + i] += 2.0 * ey * f[i]
            loss += ex * ex + ey * ey
        inv_bs = 1.0 / batch_size
        grads = [g * inv_bs for g in grads]
        loss *= inv_bs
        # gradient clip
        gnorm = math.sqrt(sum(g * g for g in grads))
        if gnorm > 60.0:
            scale = 60.0 / gnorm
            grads = [g * scale for g in grads]
        wd = 2e-4
        theta = [theta[i] - lr * (grads[i] + wd * theta[i]) for i in range(16)]
        if step % 260 == 0:
            history.append((step, loss, math.sqrt(sum(v * v for v in theta))))
    return theta, history
theta_lin, hist_lin = train_flow_matching(objective='linear')
theta_trig, hist_trig = train_flow_matching(objective='trig')
print('linear objective history:')
for step, loss, norm in hist_lin:
    print('step=', f'{step:04d}', 'loss=', round(loss, 5), '||theta||=', round(norm, 4))
print()
print('trig objective history:')
for step, loss, norm in hist_trig:
    print('step=', f'{step:04d}', 'loss=', round(loss, 5), '||theta||=', round(norm, 4))

path に沿う速度場を学習する

ここでは vtheta(x,t)v_theta(x,t) を回帰し、loss とパラメータがどう動くかを見ます。

def integrate_ode(theta, n=2600, n_steps=90, seed=505, clip=6.0):
    rng = random.Random(seed)
    xs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(n)]

    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        nxt = []
        for x, y in xs:
            vx, vy = vf_predict(theta, x, y, t)
            xn = x + dt * vx
            yn = y + dt * vy
            # 数値爆発を防ぐためのガード。理論上の流れそのものではなく数値安定化の処理。
            xn = max(-clip, min(clip, xn))
            yn = max(-clip, min(clip, yn))
            nxt.append((xn, yn))
        xs = nxt

    return xs


gen_lin = integrate_ode(theta_lin)
gen_trig = integrate_ode(theta_trig)

print('target stats =', {k: round(v, 4) for k, v in describe_2d(target_data).items()})
print('gen linear   =', {k: round(v, 4) for k, v in describe_2d(gen_lin).items()})
print('gen trig     =', {k: round(v, 4) for k, v in describe_2d(gen_trig).items()})

生成結果を path ごとに比べる

学習後は、どの path を選んだかでサンプル統計がどう変わるかを見ます。

def summary_stat_mse(samples, target_ref):
    # 指標1: 1次/2次統計と象限比率の差分をまとめた粗いMSE
    s = describe_2d(samples)
    t = describe_2d(target_ref)
    keys = ['mx', 'my', 'sx', 'sy', 'q1', 'q3']
    return sum((s[k] - t[k]) ** 2 for k in keys) / len(keys)


def wasserstein_1d(a, b):
    aa = sorted(a)
    bb = sorted(b)
    n = min(len(aa), len(bb))
    return sum(abs(aa[i] - bb[i]) for i in range(n)) / n


def sliced_wasserstein_2d(samples, target_ref, n_dirs=24, seed=999):
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_dirs):
        ang = rng.random() * 2.0 * math.pi
        ux = math.cos(ang)
        uy = math.sin(ang)
        proj_s = [ux * x + uy * y for x, y in samples]
        proj_t = [ux * x + uy * y for x, y in target_ref]
        total += wasserstein_1d(proj_s, proj_t)
    return total / n_dirs


def mode_center_mse(samples):
    # 指標3: 2モード中心への平均二乗距離(小さいほど中心に乗る)
    c1 = (-2.2, 1.6)
    c2 = (2.1, -1.5)
    err = 0.0
    for x, y in samples:
        d1 = (x - c1[0]) ** 2 + (y - c1[1]) ** 2
        d2 = (x - c2[0]) ** 2 + (y - c2[1]) ** 2
        err += min(d1, d2)
    return err / len(samples)


mse_lin = summary_stat_mse(gen_lin, target_data)
mse_trig = summary_stat_mse(gen_trig, target_data)
sw_lin = sliced_wasserstein_2d(gen_lin, target_data)
sw_trig = sliced_wasserstein_2d(gen_trig, target_data)
mc_lin = mode_center_mse(gen_lin)
mc_trig = mode_center_mse(gen_trig)

print('summary-stat MSE linear =', round(mse_lin, 6))
print('summary-stat MSE trig   =', round(mse_trig, 6))
print('sliced-W1 linear        =', round(sw_lin, 6))
print('sliced-W1 trig          =', round(sw_trig, 6))
print('mode-center MSE linear  =', round(mc_lin, 6))
print('mode-center MSE trig    =', round(mc_trig, 6))

lin_better = (mse_lin <= mse_trig) + (sw_lin <= sw_trig) + (mc_lin <= mc_trig)
if lin_better >= 2:
    print('linear path model is better on majority of toy metrics.')
else:
    print('trig path model is better on majority of toy metrics (or tied).')

time conditioning の有無を比べる

最後は with-tno-t を並べて、時刻を入れないと何が表現できなくなるのかを確認します。

# 失敗例: 時間情報 t を入力しないと何が起きるか
# 比較条件をできるだけ揃えるため、with-t側と同じ容量(16パラメータ)・同反復・同seed条件に近づける

def vf_features_no_t(x, y, t):
    # tを使わない代わりに、同容量化のため8次元特徴を使う
    return [1.0, x, y, x * x, y * y, x * y, abs(x), abs(y)]


def vf_predict_no_t(theta, x, y, t):
    f = vf_features_no_t(x, y, t)
    vx = sum(theta[i] * f[i] for i in range(8))
    vy = sum(theta[8 + i] * f[i] for i in range(8))
    return vx, vy


def train_no_t(objective='linear', steps=2600, batch_size=128, lr=0.012, seed=2026):
    rng = random.Random(seed)
    theta = [0.0] * 16
    bridge = linear_bridge if objective == 'linear' else trig_bridge

    for _ in range(steps + 1):
        grads = [0.0] * 16
        for _ in range(batch_size):
            x0 = noise_data[rng.randrange(len(noise_data))]
            x1 = target_data[rng.randrange(len(target_data))]
            t = rng.random()
            (xtx, xty), (utx, uty) = bridge(x0, x1, t)
            pvx, pvy = vf_predict_no_t(theta, xtx, xty, t)
            ex = pvx - utx
            ey = pvy - uty
            f = vf_features_no_t(xtx, xty, t)
            for i in range(8):
                grads[i] += 2.0 * ex * f[i]
                grads[8 + i] += 2.0 * ey * f[i]

        inv_bs = 1.0 / batch_size
        grads = [g * inv_bs for g in grads]

        gnorm = math.sqrt(sum(g * g for g in grads))
        if gnorm > 60.0:
            scale = 60.0 / gnorm
            grads = [g * scale for g in grads]

        wd = 2e-4
        theta = [theta[i] - lr * (grads[i] + wd * theta[i]) for i in range(16)]
    return theta


def integrate_no_t(theta, n=2600, n_steps=90, seed=505, clip=6.0):
    rng = random.Random(seed)
    xs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(n)]
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        nxt = []
        for x, y in xs:
            vx, vy = vf_predict_no_t(theta, x, y, t)
            xn = x + dt * vx
            yn = y + dt * vy
            xn = max(-clip, min(clip, xn))
            yn = max(-clip, min(clip, yn))
            nxt.append((xn, yn))
        xs = nxt
    return xs


def metric_triplet(samples):
    return (
        summary_stat_mse(samples, target_data),
        sliced_wasserstein_2d(samples, target_data),
        mode_center_mse(samples),
    )


def mean_std(vals):
    m = statistics.mean(vals)
    s = statistics.pstdev(vals)
    return m, s


theta_no_t = train_no_t(objective='linear')
eval_seeds = [501, 502, 503, 504, 505]
with_t_metrics = []
no_t_metrics = []

for sd in eval_seeds:
    g_with = integrate_ode(theta_lin, seed=sd)
    g_not = integrate_no_t(theta_no_t, seed=sd)
    with_t_metrics.append(metric_triplet(g_with))
    no_t_metrics.append(metric_triplet(g_not))

with_mse = [m[0] for m in with_t_metrics]
with_sw = [m[1] for m in with_t_metrics]
with_mc = [m[2] for m in with_t_metrics]
no_mse = [m[0] for m in no_t_metrics]
no_sw = [m[1] for m in no_t_metrics]
no_mc = [m[2] for m in no_t_metrics]

wm, ws = mean_std(with_mse)
nm, ns = mean_std(no_mse)
ww, wws = mean_std(with_sw)
nw, nws = mean_std(no_sw)
wc, wcs = mean_std(with_mc)
nc, ncs = mean_std(no_mc)

print('seeds =', eval_seeds)
print('summary-stat MSE with-t =', round(wm, 6), '+/-', round(ws, 6))
print('summary-stat MSE no-t   =', round(nm, 6), '+/-', round(ns, 6))
print('sliced-W1 with-t        =', round(ww, 6), '+/-', round(wws, 6))
print('sliced-W1 no-t          =', round(nw, 6), '+/-', round(nws, 6))
print('mode-center with-t      =', round(wc, 6), '+/-', round(wcs, 6))
print('mode-center no-t        =', round(nc, 6), '+/-', round(ncs, 6))

時間入力 t を消すと、同じ場所 x に対して全時刻で同じ速度しか出せません。これでは「初期は大きく動いて、終盤は微調整する」といった連続時間生成の本質を表現しにくくなります。

この比較では、with-tno-t でパラメータ数・学習反復数・評価seed群を可能な範囲で揃えています。それでも一般には最適化の偶然差や特徴設計差は残るため、最終判断は複数runで確認する運用が安全です。

実装上の含意として、Flow Matching では t の埋め込み設計(sin-cos 埋め込みやMLP)を丁寧に作る価値が高い、という点を押さえておけば十分です。