スコアベースモデルと拡散モデル

このノートでは、スコアベースモデルと拡散モデルを「なぜその設計になるのか」から順に確認します。ポイントは、

  1. データ密度そのものではなく、log p(x) の勾配(スコア)を学ぶ
  2. そのスコア(または等価なノイズ予測)を使って逆方向にサンプリングする

という流れです。数式だけでなく、1次元データの最小実装で動作を確かめます。

拡散モデルは、いきなり画像を描くのでなく、ノイズから少しずつ戻る

難しい分布を一発で当てる代わりに、いったん壊してから戻す。この発想が拡散モデルの中心です。スコアベースモデルは、その戻る向きを log p(x) の勾配として表します。

このノートでは、まずスコアが何を向いている量かを確認し、次に DSM がどんな教師信号を与えているかを見ます。そのあとで DDPM 形式のノイズ予測へ移り、最後に逆過程のステップを減らしたときの品質劣化まで見ます。

読み筋は「密度そのもの」より「戻る向き」です

高次元で密度値をそのまま扱うのは難しいので、どちらへ動けば確率が上がるかという方向情報を学びます。後半では、この方向情報がノイズ予測とどうつながるかを見ます。

最初に 2 モード分布を使うのは、戻りきれているかを見やすくするためです

単峰分布だと、多少ずれても違いが目立ちません。2 つの山があると、逆過程がうまく働かないときの崩れ方がかなり分かりやすくなります。

DSM の要点は、ノイズを足した問題の方が学びやすいことです

元の密度を直接推定する代わりに、ノイズ付きデータから元へ戻る向きを学びます。この迂回が、拡散モデルを安定にしている発想の核です。

DDPM のノイズ予測は、スコアの別表現として読む

eps を当てているように見えても、やっていることは逆過程を進めるための方向情報の学習です。ここが見えると、スコアモデルと DDPM を別物として覚えなくて済みます。

最後は、局所誤差と最終品質が一致しないことを見る

ノイズ予測誤差が小さくても、ステップを粗くすると最終サンプルは崩れます。ここが拡散モデルの高速化を考える入口になります。

まずは壊してから戻す流れを見る

最初の節では、前向きノイズ化でモードがつぶれ、逆向きには予測器の質で戻り方が変わることを観察します。

import math
import random
import statistics
import time

スコアは

xlogp(x)\nabla_x \log p(x)

です。密度の絶対値より「どちらへ動くと密度が上がるか」を直接教えてくれる量なので、サンプリング更新に使いやすいのが利点です。

拡散モデルでは、前向き過程でデータにノイズを足し、逆過程でノイズを除去します。DDPM の前向き過程は離散時間で

q(xtxt1)=N(1βtxt1,βtI)q(x_t|x_{t-1})=\mathcal{N}(\sqrt{1-\beta_t}x_{t-1},\beta_t I)

と書け、さらに

xt=αˉtx0+1αˉtϵ,ϵN(0,I)x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\,\epsilon, \quad \epsilon\sim\mathcal{N}(0,I)

で直接サンプルできます。

まず学習対象として、2モードの1次元分布を作ります。1モードだと拡散の価値が見えにくいので、モード間の遷移が必要なデータを使います。

random.seed(37)

def sample_data(n: int):
    xs = []
    for _ in range(n):
        if random.random() < 0.58:
            xs.append(random.gauss(-2.1, 0.48))
        else:
            xs.append(random.gauss(1.9, 0.55))
    return xs


def describe(xs):
    left = sum(1 for x in xs if x < 0.0) / len(xs)
    right = 1.0 - left
    return {
        'mean': statistics.mean(xs),
        'std': statistics.pstdev(xs),
        'left': left,
        'right': right,
    }


data = sample_data(2600)
st = describe(data)
print('dataset size =', len(data))
print('mean/std     =', round(st['mean'], 4), round(st['std'], 4))
print('left/right   =', round(st['left'], 4), round(st['right'], 4))

次に「ノイズ付きデータでスコアを学ぶ」体験をします。Denoising Score Matching の教師信号は

target = -(x_t - x_0) / sigma^2

で、x_t = x_0 + sigma * eps のとき target = -eps/sigma になります。

なぜこれでスコアが学べるかというと、条件付き期待値をとると

E[target | x_t] = ∇_{x_t} log p_sigma(x_t)

が成り立つためです。つまり、各サンプルではノイジーな教師でも、平均的には「ノイズ付き分布のスコア」を向いています。ここが初学者が最初に躓きやすいポイントです。

# RBF特徴のスコアモデル: s_theta(x, sigma)
# 多項式より発散しにくく、初学者が挙動を追いやすい
SCORE_CENTERS = [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
SCORE_RBF_SCALE = 1.2


def score_features(x, sigma):
    rbfs = [math.exp(-0.5 * ((x - c) / SCORE_RBF_SCALE) ** 2) for c in SCORE_CENTERS]
    return [1.0] + rbfs + [sigma] + [sigma * r for r in rbfs]


def score_predict(theta, x, sigma):
    feats = score_features(x, sigma)
    return sum(w * f for w, f in zip(theta, feats))


def dsm_batch(theta, x0_batch, sigma, rng):
    loss = 0.0
    grad = [0.0] * len(theta)
    n = len(x0_batch)
    for x0 in x0_batch:
        eps = rng.gauss(0.0, 1.0)
        xt = x0 + sigma * eps
        target = -(xt - x0) / (sigma * sigma)
        pred = score_predict(theta, xt, sigma)
        err = pred - target

        feats = score_features(xt, sigma)
        for i in range(len(theta)):
            grad[i] += 2.0 * err * feats[i]
        loss += err * err

    return loss / n, [g / n for g in grad]


def train_score_model(dataset, sigmas, steps=1800, batch_size=128, lr=0.004, seed=123):
    rng = random.Random(seed)
    n_feat = 1 + len(SCORE_CENTERS) + 1 + len(SCORE_CENTERS)
    theta = [0.0] * n_feat
    history = []

    for step in range(steps + 1):
        sigma = sigmas[rng.randrange(len(sigmas))]
        x0_batch = [dataset[rng.randrange(len(dataset))] for _ in range(batch_size)]
        loss, grad = dsm_batch(theta, x0_batch, sigma, rng)

        # 勾配クリップで安定化
        gnorm = math.sqrt(sum(g * g for g in grad))
        if gnorm > 20.0:
            scale = 20.0 / gnorm
            grad = [g * scale for g in grad]

        wd = 2e-4
        theta = [t - lr * (g + wd * t) for t, g in zip(theta, grad)]

        if step % 180 == 0:
            history.append((step, loss, sigma, theta[:]))

    return theta, history


sigmas = [0.2, 0.35, 0.55, 0.85, 1.1]
score_theta, score_hist = train_score_model(data, sigmas)

for step, loss, sigma, th in score_hist:
    print(f'step={step:04d}', 'sigma=', sigma, 'loss=', round(loss, 5), '||theta||=', round(math.sqrt(sum(v*v for v in th)), 4))

スコアがどちら向きを指しているかを見る

ここでは 1 次元の例で、スコアの符号がどちらへ戻せと言っているかを確かめます。

# 学習したスコアの向きを確認
# x を少し動かしたとき、score の符号が密度上昇方向と整合しているかを見る

def score_direction_demo(theta, sigma):
    probes = [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
    for x in probes:
        s = score_predict(theta, x, sigma)
        direction = 'right(+)' if s > 0 else 'left(-)'
        print('x=', f'{x:>4.1f}', 'score=', round(s, 4), 'move=', direction)

print('score direction at sigma=0.35')
score_direction_demo(score_theta, sigma=0.35)

次に DDPM 形式へ移ります。ここではノイズ予測モデル eps_theta(x_t, t) を直接学習します。実際の U-Net の代わりに、時間ごとの低容量予測器

eps_hat = a_t * x_t + b_t + c_t * x_t^3 / 8

を使い、目的関数(ノイズMSE)と逆拡散式のつながりを確認します。

def make_beta_schedule(T, beta_start=0.0008, beta_end=0.12):
    betas = []
    for i in range(T):
        r = i / (T - 1) if T > 1 else 0.0
        betas.append(beta_start + (beta_end - beta_start) * r)
    return betas


def cumulative_product(vals):
    out = []
    c = 1.0
    for v in vals:
        c *= v
        out.append(c)
    return out


T = 50
betas = make_beta_schedule(T)
alphas = [1.0 - b for b in betas]
alpha_bars = cumulative_product(alphas)

print('beta[0], beta[-1] =', round(betas[0], 6), round(betas[-1], 6))
print('alpha_bar[0], alpha_bar[-1] =', round(alpha_bars[0], 6), round(alpha_bars[-1], 6))
print('note: alpha_bar[-1] が十分小さいので x_T はほぼガウスに近い')

次に、ノイズ予測器を学習する

DDPM 風の書き方へ移り、eps_theta(x_t, t) がどれくらいノイズを当てられるかを見ます。

def sample_xt(x0, t, rng):
    eps = rng.gauss(0.0, 1.0)
    ab = alpha_bars[t]
    xt = math.sqrt(ab) * x0 + math.sqrt(1.0 - ab) * eps
    return xt, eps
# eps predictor parameters per time step
# eps_hat = a_t * x_t + b_t + c_t * x_t^3 / 8
# 低容量ながら、単純線形より非線形性を少し持てるようにする
def init_eps_model(T):
    return [0.0] * T, [0.0] * T, [0.0] * T  # a_t, b_t, c_t
def train_eps_predictor(dataset, T, steps=5000, batch_size=128, lr=0.03, seed=777):
    rng = random.Random(seed)
    a, b, c = init_eps_model(T)
    history = []
    for step in range(steps + 1):
        t = rng.randrange(T)
        ga = 0.0
        gb = 0.0
        gc = 0.0
        loss = 0.0
        for _ in range(batch_size):
            x0 = dataset[rng.randrange(len(dataset))]
            xt, eps = sample_xt(x0, t, rng)
            x3 = (xt * xt * xt) / 8.0
            pred = a[t] * xt + b[t] + c[t] * x3
            err = pred - eps
            ga += 2.0 * err * xt
            gb += 2.0 * err
            gc += 2.0 * err * x3
            loss += err * err
        ga /= batch_size
        gb /= batch_size
        gc /= batch_size
        loss /= batch_size
        a[t] -= lr * ga
        b[t] -= lr * gb
        c[t] -= lr * gc
        # stability
        a[t] = max(-4.0, min(4.0, a[t]))
        b[t] = max(-4.0, min(4.0, b[t]))
        c[t] = max(-4.0, min(4.0, c[t]))
        if step % 350 == 0:
            history.append((step, loss, t, a[t], b[t], c[t]))
    return a, b, c, history
a_t, b_t, c_t, eps_hist = train_eps_predictor(data, T)
for step, loss, t, av, bv, cv in eps_hist:
    print(f'step={step:04d}', 'sampled t=', t, 'loss=', round(loss, 5), 'a_t=', round(av, 4), 'b_t=', round(bv, 4), 'c_t=', round(cv, 4))

逆過程を回してサンプルを戻す

学習した予測器を使って、ノイズからデータ側へ戻るサンプリングを実際に試します。

def eps_predict(xt, t, a, b, c):
    return a[t] * xt + b[t] + c[t] * (xt * xt * xt) / 8.0


def reverse_sample_ddpm(a, b, c, n=2600, seed=404):
    rng = random.Random(seed)
    samples = [rng.gauss(0.0, 1.0) for _ in range(n)]  # x_T

    for t in range(T - 1, -1, -1):
        bt = betas[t]
        at = alphas[t]
        ab = alpha_bars[t]

        new_samples = []
        for xt in samples:
            eps_hat = eps_predict(xt, t, a, b, c)
            mu = (xt - (bt / math.sqrt(max(1e-8, 1.0 - ab))) * eps_hat) / math.sqrt(at)

            if t > 0:
                z = rng.gauss(0.0, 1.0)
                sigma = math.sqrt(bt)
                x_prev = mu + sigma * z
            else:
                x_prev = mu

            x_prev = max(-6.0, min(6.0, x_prev))
            new_samples.append(x_prev)
        samples = new_samples

    return samples


gen_samples = reverse_sample_ddpm(a_t, b_t, c_t)
real_stats = describe(data)
gen_stats = describe(gen_samples)

print('real mean/std  =', round(real_stats['mean'], 4), round(real_stats['std'], 4))
print('gen  mean/std  =', round(gen_stats['mean'], 4), round(gen_stats['std'], 4))
print('real left/right=', round(real_stats['left'], 4), round(real_stats['right'], 4))
print('gen  left/right=', round(gen_stats['left'], 4), round(gen_stats['right'], 4))

ステップを間引いたときの崩れを見る

最後は高速化の代償として、粗い逆過程が左右モードや分散をどう崩すかを確認します。

# t ごとの予測品質を簡易評価
def eval_eps_mse(dataset, t, a, b, c, n=500, seed=888):
    rng = random.Random(seed + t)
    loss = 0.0
    for _ in range(n):
        x0 = dataset[rng.randrange(len(dataset))]
        xt, eps = sample_xt(x0, t, rng)
        pred = eps_predict(xt, t, a, b, c)
        loss += (pred - eps) ** 2
    return loss / n
for t in [0, 7, 14, 21, 28, 35, 42, 49]:
    print('t=', f'{t:>2}', 'mse=', round(eval_eps_mse(data, t, a_t, b_t, c_t), 5),
          'a_t=', round(a_t[t], 4), 'b_t=', round(b_t[t], 4), 'c_t=', round(c_t[t], 4))

ここまでで、次の対応関係が見えます。

つまり、モデルが学んでいる本質は「密度が高くなる方向」です。表現がスコアかノイズかの違いはありますが、目的は同じです。

このノートの予測器はあえて低容量なので、生成統計(特に分散)が実データより小さくなります。ここは失敗ではなく、容量制約がサンプル品質へどう影響するかを見るための設定です。

今回は alpha_bar[-1] を小さくするスケジュールを使い、x_T ~ N(0,1) で開始する逆過程との前提を揃えています。

# 失敗例: 逆過程ステップを間引くとどうなるか
def reverse_sample_with_skip(a, b, c, skip=1, n=2200, seed=909):
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]
    t_values = list(range(T - 1, -1, -skip))
    if t_values[-1] != 0:
        t_values.append(0)
    for idx, t in enumerate(t_values):
        bt = betas[t]
        at = alphas[t]
        ab = alpha_bars[t]
        new_xs = []
        for xt in xs:
            eps_hat = eps_predict(xt, t, a, b, c)
            mu = (xt - (bt / math.sqrt(max(1e-8, 1.0 - ab))) * eps_hat) / math.sqrt(at)
            if t > 0:
                z = rng.gauss(0.0, 1.0)
                sigma = math.sqrt(bt)
                x_prev = mu + sigma * z
            else:
                x_prev = mu
            x_prev = max(-6.0, min(6.0, x_prev))
            new_xs.append(x_prev)
        xs = new_xs
    return xs
for skip in [1, 2, 5, 10]:
    smp = reverse_sample_with_skip(a_t, b_t, c_t, skip=skip)
    st = describe(smp)
    print('skip=', skip, 'mean/std=', round(st['mean'], 4), round(st['std'], 4), 'left/right=', round(st['left'], 4), round(st['right'], 4))
print('target left/right =', round(real_stats['left'], 4), round(real_stats['right'], 4))

eval_eps_mse がある程度下がっていても、逆過程を間引くと左右モード比率や分散が崩れることがあります。拡散モデルでは「局所のノイズ予測精度」と「最終サンプル品質」が完全には一致しない、という点をここで押さえてください。

逆過程を粗くすると生成品質が落ちやすいのは、各ステップで行う「少しずつの方向修正」を省略してしまうためです。実務では、ステップ数とサンプル品質のトレードオフを見ながら、DDIM などの高速化手法を選びます。

このノートで押さえるべき本質は、「拡散モデルはノイズを足す過程を作ることで、逆向き学習を安定化している」という点です。これが「なぜ拡散モデルが強いか」の出発点になります。