連続時間拡散モデルとフローマッチング理論
このノートでは、連続時間拡散と Flow Matching を同じ座標系で理解します。狙いは「式を覚える」ことではなく、
- どんな中間経路(path)を選ぶか
- その経路に沿う速度場(vector field)をどう学習するか
- その速度場を積分して、どう生成するか
を一貫して追えるようにすることです。
Flow Matching は、途中の「正しい動き方」を直接学ぶ
拡散モデルではノイズ除去として逆過程を組み立てますが、Flow Matching ではノイズ点からデータ点へ至る経路上の速度場そのものを学びます。何を予測するかが違うだけで、連続時間生成としてはかなり近い世界にいます。
このノートでは、まずノイズとデータを結ぶ path を定義し、次にその path に沿う目標速度 を回帰します。そのあとで time conditioning の意味、path の選び方、OT-CFM が何を工夫しているかまでを見て、連続時間生成の見取り図を作ります。
読み筋は「どの道を通るか」と「その道でどう動くか」です
Flow Matching の設計自由度は、path と vector field の 2 箇所にあります。ここが見えると、線形ブリッジと別のブリッジの違いも、単なる式差分ではなく学習対象の違いとして読めます。
最初に path を並べるのは、生成器より経路設計が主役だからです
ノイズからデータへ行くなら何でもよいわけではありません。どの中間点を通るかで、学習しやすさも速度場の形も変わります。
速度場を学ぶとは、その時刻でどちらへ動くかを当てること
スコアベースモデルが確率上昇方向を学ぶのに対し、Flow Matching は経路に沿った目標速度を直接回帰します。ここがこの notebook のいちばん大事な違いです。
時刻入力が必要なのは、同じ場所でも「いま何時か」で動き方が変わるから
連続時間生成では、初期と終盤で望ましい速度が違います。t を消すと、その違いを表せなくなります。
OT-CFM は loss を変えるより、ペアの作り方を変える
ここは初学者が誤解しやすい点です。何を回帰するかは同じでも、どの (x_0, x_1) を結びつけるかを工夫すると、学習しやすい flow が作れるようになります。
まずはノイズ点とデータ点を橋で結ぶ
最初の節では、ノイズ源とターゲット分布を置き、線形 path と別の path がどう違うかを可視化します。
import math
import random
import statistics
import time
Flow Matching の基本は、ノイズ点 とデータ点 を結ぶ経路 を設計し、対応する目標速度 を回帰することです。
最初に記号をそろえます。
- : 時刻
- : 初期ノイズ点
- : データ点
- : 時刻 の中間点
- : 経路が持つ教師速度
- : ニューラルネットが出す予測速度
- : 生成開始時点()の初期分布
学習目的は
です。二乗誤差回帰の一般事実として、各点 における最適予測は条件付き期待値になります。したがって最適解は
であり、「各ペア速度をそのまま暗記」ではなく「同じ に集まる速度の平均場」を学習している、と読むのが正確です。
学習後は ODE
を積分して生成します。
連続時間拡散モデルとの接続(上級・初回は読み飛ばしてもOK):
前向き過程を SDE
で定めたとき、ここで
- : ブラウン運動(連続時間のランダム揺らぎ)
- : 時刻 の分布密度
- score : 密度の対数勾配(高密度方向を指す量)
です。代表的な前提(例えば が状態非依存の設定)では確率流 ODE を
と書けます。拡散側は score を経由して速度場を得るのに対し、Flow Matching は経路と教師速度を直接設計して速度場を学習します。
注意: 拡散文献では時間向きを data→noise で書くことが多いのに対し、このノートは Flow Matching の実装導線に合わせて noise→data 向きで記述しています。
注意: このノートでは x_0=ノイズ, x_1=データ と置きます。拡散文献の一部では をデータと書くため、添字の意味を都度確認してください。
まず2次元データを用意します。ノイズ源は標準ガウス、ターゲットは2峰の2次元混合ガウスです。
random.seed(41)
def sample_noise_2d(n: int):
out = []
for _ in range(n):
out.append((random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)))
return out
def sample_data_2d(n: int):
out = []
for _ in range(n):
if random.random() < 0.5:
x = random.gauss(-2.2, 0.45)
y = random.gauss(1.6, 0.55)
else:
x = random.gauss(2.1, 0.5)
y = random.gauss(-1.5, 0.5)
out.append((x, y))
return out
def describe_2d(xs):
mx = statistics.mean(x for x, _ in xs)
my = statistics.mean(y for _, y in xs)
sx = statistics.pstdev(x for x, _ in xs)
sy = statistics.pstdev(y for _, y in xs)
q1 = sum(1 for x, y in xs if x < 0 and y >= 0) / len(xs)
q3 = sum(1 for x, y in xs if x >= 0 and y < 0) / len(xs)
return {'mx': mx, 'my': my, 'sx': sx, 'sy': sy, 'q1': q1, 'q3': q3}
noise_data = sample_noise_2d(3200)
target_data = sample_data_2d(3200)
print('noise stats =', {k: round(v, 4) for k, v in describe_2d(noise_data).items()})
print('target stats=', {k: round(v, 4) for k, v in describe_2d(target_data).items()})
次に2種類のパスを定義します。
- 線形パス(独立カップリング版の CFM)
x_t = (1-t)x_0 + t x_1- 速度は定数
u_t = x_1 - x_0
- 三角関数パス(連続時間拡散のノイズ減衰を意識した滑らかな経路)
x_t = alpha(t) x_1 + sigma(t) x_0- ここで
alpha(t)=sin(pi t/2),sigma(t)=cos(pi t/2)
DDPM 的な離散式 x_t = sqrt(alpha_bar_t) x_1 + sqrt(1-alpha_bar_t) x_0 を連続時間で見ると、「データ係数を増やしノイズ係数を減らす」経路設計になります。このノートの三角関数パスは、その直観を連続時間で観察しやすくした簡易版です。
係数の違いも押さえておきます。線形は alpha+sigma=1、三角関数は です。例えば t=0.5 では線形が (0.5, 0.5)、三角関数が (0.707..., 0.707...) なので、中間の混ざり方が異なります。
注意: ここでの線形パスは x_0, x_1 を独立に引いた組を使っており、厳密な OT カップリングではありません。
参考資料で扱う OT-CFM(Optimal-Transport Conditional Flow Matching)では、損失関数の形自体は同じでも、学習に使う の組を最適輸送カップリングで作る点が重要です。
直観的には、ノイズ側の点とデータ側の点を「できるだけ無理のない対応」で結び、学習すべきベクトル場を素直にする工夫です。ミニバッチ単位で近似 OT を解く実装が一般的で、画像規模では計算コストと近似精度のトレードオフ設計が実務上のポイントになります。
def linear_bridge(x0, x1, t):
x = (1.0 - t) * x0[0] + t * x1[0]
y = (1.0 - t) * x0[1] + t * x1[1]
vx = x1[0] - x0[0]
vy = x1[1] - x0[1]
return (x, y), (vx, vy)
def trig_bridge(x0, x1, t):
# alpha(0)=0, sigma(0)=1; alpha(1)=1, sigma(1)=0
alpha = math.sin(0.5 * math.pi * t)
sigma = math.cos(0.5 * math.pi * t)
alpha_dot = 0.5 * math.pi * math.cos(0.5 * math.pi * t)
sigma_dot = -0.5 * math.pi * math.sin(0.5 * math.pi * t)
x = alpha * x1[0] + sigma * x0[0]
y = alpha * x1[1] + sigma * x0[1]
vx = alpha_dot * x1[0] + sigma_dot * x0[0]
vy = alpha_dot * x1[1] + sigma_dot * x0[1]
return (x, y), (vx, vy)
def path_length_demo(n=600):
# 同じカップル(x0,x1)に対し、2パスの移動距離を比較
total_linear = 0.0
total_trig = 0.0
for _ in range(n):
x0 = noise_data[random.randrange(len(noise_data))]
x1 = target_data[random.randrange(len(target_data))]
# linear length (closed form)
dl = math.sqrt((x1[0] - x0[0]) ** 2 + (x1[1] - x0[1]) ** 2)
# trig length (numerical)
prev, _ = trig_bridge(x0, x1, 0.0)
dt = 1.0 / 120
s = 0.0
for i in range(1, 121):
t = i * dt
cur, _ = trig_bridge(x0, x1, t)
s += math.sqrt((cur[0] - prev[0]) ** 2 + (cur[1] - prev[1]) ** 2)
prev = cur
total_linear += dl
total_trig += s
return total_linear / n, total_trig / n
lin_len, trig_len = path_length_demo()
print('avg path length linear =', round(lin_len, 4))
print('avg path length trig =', round(trig_len, 4))
print('ratio trig/linear =', round(trig_len / lin_len, 4))
速度場モデル は、2次元入力 + 時刻に対する低容量回帰器にします。U-Netではなく軽量モデルを使うのは、目的が理論理解だからです。
この設定でも、Flow Matching の学習が「ODEのロールアウトなし」にできることは確認できます。
# feature map for vector field regression
# phi(x,t) -> 8 features
def vf_features(x, y, t):
return [1.0, x, y, t, x * t, y * t, x * x, y * y]
def vf_predict(theta, x, y, t):
# theta has 16 params: first 8 for vx, last 8 for vy
f = vf_features(x, y, t)
vx = sum(theta[i] * f[i] for i in range(8))
vy = sum(theta[8 + i] * f[i] for i in range(8))
return vx, vy
def train_flow_matching(objective='linear', steps=2600, batch_size=128, lr=0.012, seed=2026):
rng = random.Random(seed)
theta = [0.0] * 16
history = []
bridge = linear_bridge if objective == 'linear' else trig_bridge
for step in range(steps + 1):
grads = [0.0] * 16
loss = 0.0
for _ in range(batch_size):
x0 = noise_data[rng.randrange(len(noise_data))]
x1 = target_data[rng.randrange(len(target_data))]
t = rng.random()
(xtx, xty), (utx, uty) = bridge(x0, x1, t)
pvx, pvy = vf_predict(theta, xtx, xty, t)
ex = pvx - utx
ey = pvy - uty
f = vf_features(xtx, xty, t)
for i in range(8):
grads[i] += 2.0 * ex * f[i]
grads[8 + i] += 2.0 * ey * f[i]
loss += ex * ex + ey * ey
inv_bs = 1.0 / batch_size
grads = [g * inv_bs for g in grads]
loss *= inv_bs
# gradient clip
gnorm = math.sqrt(sum(g * g for g in grads))
if gnorm > 60.0:
scale = 60.0 / gnorm
grads = [g * scale for g in grads]
wd = 2e-4
theta = [theta[i] - lr * (grads[i] + wd * theta[i]) for i in range(16)]
if step % 260 == 0:
history.append((step, loss, math.sqrt(sum(v * v for v in theta))))
return theta, history
theta_lin, hist_lin = train_flow_matching(objective='linear')
theta_trig, hist_trig = train_flow_matching(objective='trig')
print('linear objective history:')
for step, loss, norm in hist_lin:
print('step=', f'{step:04d}', 'loss=', round(loss, 5), '||theta||=', round(norm, 4))
print()
print('trig objective history:')
for step, loss, norm in hist_trig:
print('step=', f'{step:04d}', 'loss=', round(loss, 5), '||theta||=', round(norm, 4))
path に沿う速度場を学習する
ここでは を回帰し、loss とパラメータがどう動くかを見ます。
def integrate_ode(theta, n=2600, n_steps=90, seed=505, clip=6.0):
rng = random.Random(seed)
xs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(n)]
dt = 1.0 / n_steps
for k in range(n_steps):
t = k * dt
nxt = []
for x, y in xs:
vx, vy = vf_predict(theta, x, y, t)
xn = x + dt * vx
yn = y + dt * vy
# 数値爆発を防ぐためのガード。理論上の流れそのものではなく数値安定化の処理。
xn = max(-clip, min(clip, xn))
yn = max(-clip, min(clip, yn))
nxt.append((xn, yn))
xs = nxt
return xs
gen_lin = integrate_ode(theta_lin)
gen_trig = integrate_ode(theta_trig)
print('target stats =', {k: round(v, 4) for k, v in describe_2d(target_data).items()})
print('gen linear =', {k: round(v, 4) for k, v in describe_2d(gen_lin).items()})
print('gen trig =', {k: round(v, 4) for k, v in describe_2d(gen_trig).items()})
生成結果を path ごとに比べる
学習後は、どの path を選んだかでサンプル統計がどう変わるかを見ます。
def summary_stat_mse(samples, target_ref):
# 指標1: 1次/2次統計と象限比率の差分をまとめた粗いMSE
s = describe_2d(samples)
t = describe_2d(target_ref)
keys = ['mx', 'my', 'sx', 'sy', 'q1', 'q3']
return sum((s[k] - t[k]) ** 2 for k in keys) / len(keys)
def wasserstein_1d(a, b):
aa = sorted(a)
bb = sorted(b)
n = min(len(aa), len(bb))
return sum(abs(aa[i] - bb[i]) for i in range(n)) / n
def sliced_wasserstein_2d(samples, target_ref, n_dirs=24, seed=999):
rng = random.Random(seed)
total = 0.0
for _ in range(n_dirs):
ang = rng.random() * 2.0 * math.pi
ux = math.cos(ang)
uy = math.sin(ang)
proj_s = [ux * x + uy * y for x, y in samples]
proj_t = [ux * x + uy * y for x, y in target_ref]
total += wasserstein_1d(proj_s, proj_t)
return total / n_dirs
def mode_center_mse(samples):
# 指標3: 2モード中心への平均二乗距離(小さいほど中心に乗る)
c1 = (-2.2, 1.6)
c2 = (2.1, -1.5)
err = 0.0
for x, y in samples:
d1 = (x - c1[0]) ** 2 + (y - c1[1]) ** 2
d2 = (x - c2[0]) ** 2 + (y - c2[1]) ** 2
err += min(d1, d2)
return err / len(samples)
mse_lin = summary_stat_mse(gen_lin, target_data)
mse_trig = summary_stat_mse(gen_trig, target_data)
sw_lin = sliced_wasserstein_2d(gen_lin, target_data)
sw_trig = sliced_wasserstein_2d(gen_trig, target_data)
mc_lin = mode_center_mse(gen_lin)
mc_trig = mode_center_mse(gen_trig)
print('summary-stat MSE linear =', round(mse_lin, 6))
print('summary-stat MSE trig =', round(mse_trig, 6))
print('sliced-W1 linear =', round(sw_lin, 6))
print('sliced-W1 trig =', round(sw_trig, 6))
print('mode-center MSE linear =', round(mc_lin, 6))
print('mode-center MSE trig =', round(mc_trig, 6))
lin_better = (mse_lin <= mse_trig) + (sw_lin <= sw_trig) + (mc_lin <= mc_trig)
if lin_better >= 2:
print('linear path model is better on majority of toy metrics.')
else:
print('trig path model is better on majority of toy metrics (or tied).')
time conditioning の有無を比べる
最後は with-t と no-t を並べて、時刻を入れないと何が表現できなくなるのかを確認します。
# 失敗例: 時間情報 t を入力しないと何が起きるか
# 比較条件をできるだけ揃えるため、with-t側と同じ容量(16パラメータ)・同反復・同seed条件に近づける
def vf_features_no_t(x, y, t):
# tを使わない代わりに、同容量化のため8次元特徴を使う
return [1.0, x, y, x * x, y * y, x * y, abs(x), abs(y)]
def vf_predict_no_t(theta, x, y, t):
f = vf_features_no_t(x, y, t)
vx = sum(theta[i] * f[i] for i in range(8))
vy = sum(theta[8 + i] * f[i] for i in range(8))
return vx, vy
def train_no_t(objective='linear', steps=2600, batch_size=128, lr=0.012, seed=2026):
rng = random.Random(seed)
theta = [0.0] * 16
bridge = linear_bridge if objective == 'linear' else trig_bridge
for _ in range(steps + 1):
grads = [0.0] * 16
for _ in range(batch_size):
x0 = noise_data[rng.randrange(len(noise_data))]
x1 = target_data[rng.randrange(len(target_data))]
t = rng.random()
(xtx, xty), (utx, uty) = bridge(x0, x1, t)
pvx, pvy = vf_predict_no_t(theta, xtx, xty, t)
ex = pvx - utx
ey = pvy - uty
f = vf_features_no_t(xtx, xty, t)
for i in range(8):
grads[i] += 2.0 * ex * f[i]
grads[8 + i] += 2.0 * ey * f[i]
inv_bs = 1.0 / batch_size
grads = [g * inv_bs for g in grads]
gnorm = math.sqrt(sum(g * g for g in grads))
if gnorm > 60.0:
scale = 60.0 / gnorm
grads = [g * scale for g in grads]
wd = 2e-4
theta = [theta[i] - lr * (grads[i] + wd * theta[i]) for i in range(16)]
return theta
def integrate_no_t(theta, n=2600, n_steps=90, seed=505, clip=6.0):
rng = random.Random(seed)
xs = [(rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)) for _ in range(n)]
dt = 1.0 / n_steps
for k in range(n_steps):
t = k * dt
nxt = []
for x, y in xs:
vx, vy = vf_predict_no_t(theta, x, y, t)
xn = x + dt * vx
yn = y + dt * vy
xn = max(-clip, min(clip, xn))
yn = max(-clip, min(clip, yn))
nxt.append((xn, yn))
xs = nxt
return xs
def metric_triplet(samples):
return (
summary_stat_mse(samples, target_data),
sliced_wasserstein_2d(samples, target_data),
mode_center_mse(samples),
)
def mean_std(vals):
m = statistics.mean(vals)
s = statistics.pstdev(vals)
return m, s
theta_no_t = train_no_t(objective='linear')
eval_seeds = [501, 502, 503, 504, 505]
with_t_metrics = []
no_t_metrics = []
for sd in eval_seeds:
g_with = integrate_ode(theta_lin, seed=sd)
g_not = integrate_no_t(theta_no_t, seed=sd)
with_t_metrics.append(metric_triplet(g_with))
no_t_metrics.append(metric_triplet(g_not))
with_mse = [m[0] for m in with_t_metrics]
with_sw = [m[1] for m in with_t_metrics]
with_mc = [m[2] for m in with_t_metrics]
no_mse = [m[0] for m in no_t_metrics]
no_sw = [m[1] for m in no_t_metrics]
no_mc = [m[2] for m in no_t_metrics]
wm, ws = mean_std(with_mse)
nm, ns = mean_std(no_mse)
ww, wws = mean_std(with_sw)
nw, nws = mean_std(no_sw)
wc, wcs = mean_std(with_mc)
nc, ncs = mean_std(no_mc)
print('seeds =', eval_seeds)
print('summary-stat MSE with-t =', round(wm, 6), '+/-', round(ws, 6))
print('summary-stat MSE no-t =', round(nm, 6), '+/-', round(ns, 6))
print('sliced-W1 with-t =', round(ww, 6), '+/-', round(wws, 6))
print('sliced-W1 no-t =', round(nw, 6), '+/-', round(nws, 6))
print('mode-center with-t =', round(wc, 6), '+/-', round(wcs, 6))
print('mode-center no-t =', round(nc, 6), '+/-', round(ncs, 6))
時間入力 t を消すと、同じ場所 x に対して全時刻で同じ速度しか出せません。これでは「初期は大きく動いて、終盤は微調整する」といった連続時間生成の本質を表現しにくくなります。
この比較では、with-t と no-t でパラメータ数・学習反復数・評価seed群を可能な範囲で揃えています。それでも一般には最適化の偶然差や特徴設計差は残るため、最終判断は複数runで確認する運用が安全です。
実装上の含意として、Flow Matching では t の埋め込み設計(sin-cos 埋め込みやMLP)を丁寧に作る価値が高い、という点を押さえておけば十分です。