エネルギーベースモデル

エネルギーベースモデル(EBM)は、確率分布を「正規化された確率」ではなく「エネルギー関数」で表す考え方です。データらしい点には低いエネルギーを、データらしくない点には高いエネルギーを与えるように学習します。

このノートでは、1次元の最小実装で EBM の核を確認します。特に、SGLD と Replay Buffer を使った学習が何をしているかをコードで追います。

EBM は「確率を直接書く」代わりに、「らしさの地形」を学ぶ

データらしい点には低いエネルギー、らしくない点には高いエネルギーを与える。これが EBM の基本発想です。難しいのは、確率へ直すときに必要な正規化定数が高次元ではほとんど計算できないことです。

このノートでは、まず 1 次元の 2 モード分布を相手にエネルギー地形を置き、次に SGLD で負例を作って contrastive に更新します。そのあとで分配関数近似の荒さや Replay Buffer の役割を見て、EBM 学習の実感を作ります。

読み筋は「低いところへ寄せる」ではなく、「本物を下げて偽物を上げる」です

エネルギーを下げるという言い方だけでは、何を学習しているかがぼやけます。実際には、データ点のエネルギーを下げ、モデルが今出している負例のエネルギーを上げる対比で学習が進みます。

最初の山場は、負例をどう作るかです

分類の負例のように外から与えられるわけではないので、モデル自身の地形の上でサンプルを動かして作ります。SGLD が重要になるのはこのためです。

Replay Buffer は、過去の負例を捨てずに学習を安定させる

毎回ランダム初期化からサンプリングすると、負例分布が揺れすぎます。過去の負例を再利用すると、モデルがいま低エネルギーにしている領域を継続的に点検できます。

ここでは分配関数を厳密に計算するのでなく、難しさを体感する

EBM の教科書的な難所は Z です。この notebook では、それを完全に解くより「なぜ近似が要るのか」を掴むことを優先します。

ここでは EBM の難所を、1 次元でむき出しにする

高次元画像よりずっと簡単な設定ですが、その分だけ Z の扱いづらさ、負例生成の重要性、Replay Buffer の意味が見えやすくなっています。

まずはエネルギー地形の相手となる分布を置く

最初の節では、2 モード分布をデータとして用意し、学習がうまくいけばその山を再現できるかを見やすくします。

import math
import random
import statistics
import time

EBM の基本式は次です。

pθ(x)=exp(Eθ(x))Zθ,Zθ=exp(Eθ(x))dxp_\theta(x)=\frac{\exp(-E_\theta(x))}{Z_\theta},\qquad Z_\theta=\int \exp(-E_\theta(x))dx

ここで難しいのは分配関数 ZθZ_\theta です。高次元では厳密計算がほぼ不可能なので、学習では

という contrastive な差分で更新します。IGEBM系の実装でもこの考え方が中心です。

まず学習対象データを作ります。2モード混合分布を使い、モード構造を再現できるか確認しやすくします。

random.seed(31)

def sample_real(n: int):
    xs = []
    for _ in range(n):
        if random.random() < 0.55:
            xs.append(random.gauss(-2.0, 0.45))
        else:
            xs.append(random.gauss(1.8, 0.55))
    return xs


def describe(xs):
    left = sum(1 for x in xs if x < 0) / len(xs)
    right = 1.0 - left
    return {
        'mean': statistics.mean(xs),
        'std': statistics.pstdev(xs),
        'left': left,
        'right': right,
    }

real_data = sample_real(2400)
stats = describe(real_data)
print('dataset size =', len(real_data))
print('mean/std     =', round(stats['mean'], 4), round(stats['std'], 4))
print('left/right   =', round(stats['left'], 4), round(stats['right'], 4))

今回は 2 つの井戸(well)を持つエネルギー関数を使います。

a1,a2 は正である必要があるため a=exp(rawa)a=exp(raw_a) として実装します。k1,k2 は井戸の深さを調整する項で、モード比率(重み)を表現しやすくする役割があります。

def stable_logsumexp(a, b):
    m = a if a > b else b
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def clip_raw_a(v):
    return max(-5.0, min(5.0, v))


def energy_value(x, theta):
    # theta = [raw_a1, m1, k1, raw_a2, m2, k2, c]
    raw_a1, m1, k1, raw_a2, m2, k2, c = theta
    a1 = math.exp(clip_raw_a(raw_a1))
    a2 = math.exp(clip_raw_a(raw_a2))

    e1 = a1 * (x - m1) ** 2 + k1
    e2 = a2 * (x - m2) ** 2 + k2
    return -stable_logsumexp(-e1, -e2) + c


def energy_and_grads(x, theta):
    raw_a1, m1, k1, raw_a2, m2, k2, c = theta
    raw_a1 = clip_raw_a(raw_a1)
    raw_a2 = clip_raw_a(raw_a2)
    a1 = math.exp(raw_a1)
    a2 = math.exp(raw_a2)

    d1 = x - m1
    d2 = x - m2
    e1 = a1 * d1 * d1 + k1
    e2 = a2 * d2 * d2 + k2

    u1 = math.exp(-e1)
    u2 = math.exp(-e2)
    den = u1 + u2
    w1 = u1 / den
    w2 = 1.0 - w1

    energy = -math.log(den) + c

    grad_x = w1 * (2.0 * a1 * d1) + w2 * (2.0 * a2 * d2)

    grad_raw_a1 = w1 * (d1 * d1) * a1
    grad_m1 = w1 * (-2.0 * a1 * d1)
    grad_k1 = w1

    grad_raw_a2 = w2 * (d2 * d2) * a2
    grad_m2 = w2 * (-2.0 * a2 * d2)
    grad_k2 = w2

    grad_c = 1.0

    grads = [grad_raw_a1, grad_m1, grad_k1, grad_raw_a2, grad_m2, grad_k2, grad_c]
    return energy, grad_x, grads


theta_demo = [0.1, -1.2, 0.0, 0.2, 1.3, 0.2, 0.0]
for x in [-2.0, -0.2, 1.5]:
    e, gx, gp = energy_and_grads(x, theta_demo)
    print('x=', x, 'E=', round(e, 5), 'dE/dx=', round(gx, 5), 'dE/dk1=', round(gp[2], 5))

次に SGLD で負例を作ります。

x <- x - step_size * dE/dx + noise_std * N(0,1)

-dE/dx はエネルギーを下げる方向なので、低エネルギー領域に集まりやすくなります。ノイズ項は探索性を保つ役割を持ちます。

class SampleReplayBuffer:
    def __init__(self, max_size=12000):
        self.max_size = max_size
        self.items = []

    def add(self, xs):
        self.items.extend(xs)
        if len(self.items) > self.max_size:
            self.items = self.items[-self.max_size :]

    def sample(self, n):
        if not self.items:
            return []
        if n >= len(self.items):
            return self.items[:]
        return random.sample(self.items, n)


def sgld_samples(theta, x_init, n_steps=40, step_size=0.06, noise_std=0.08):
    xs = x_init[:]
    for _ in range(n_steps):
        for i in range(len(xs)):
            _, grad_x, _ = energy_and_grads(xs[i], theta)
            xs[i] = xs[i] - step_size * grad_x + noise_std * random.gauss(0.0, 1.0)
            xs[i] = max(-6.0, min(6.0, xs[i]))
    return xs


buffer = SampleReplayBuffer(max_size=5000)
start = [random.uniform(-4.0, 4.0) for _ in range(8)]
sgld_out = sgld_samples(theta_demo, start, n_steps=30)
print('init  =', [round(v, 3) for v in start])
print('sgld  =', [round(v, 3) for v in sgld_out])
buffer.add(sgld_out)
print('buffer size =', len(buffer.items))

学習目的は

L(θ)=E_data[Eθ(x)] - E_neg[Eθ(x)] + λ * regularizer

です。これを最小化すると、データ近傍は低エネルギー、負例近傍は高エネルギーになります。

def mean_grads(xs, theta):
    g = [0.0] * 7
    e_sum = 0.0
    for x in xs:
        e, _, gp = energy_and_grads(x, theta)
        e_sum += e
        for i in range(7):
            g[i] += gp[i]
    n = len(xs)
    return e_sum / n, [v / n for v in g]


def train_ebm(theta_init, real_data, steps=380, batch_size=96, lr=0.016, reg=8e-4, replay_ratio=0.7):
    theta = theta_init[:]
    replay = SampleReplayBuffer(max_size=14000)
    history = []

    for step in range(steps + 1):
        x_data = random.sample(real_data, batch_size)

        n_replay = int(batch_size * replay_ratio)
        x_neg0 = replay.sample(n_replay)
        if len(x_neg0) < batch_size:
            x_neg0 = x_neg0 + [random.uniform(-4.5, 4.5) for _ in range(batch_size - len(x_neg0))]

        x_neg = sgld_samples(theta, x_neg0, n_steps=30, step_size=0.055, noise_std=0.085)
        replay.add(x_neg)

        e_data, g_data = mean_grads(x_data, theta)
        e_neg, g_neg = mean_grads(x_neg, theta)

        # L = E_data - E_neg + reg * (raw_a1^2 + raw_a2^2 + 0.2*(k1^2+k2^2))
        grads = [g_data[i] - g_neg[i] for i in range(7)]
        grads[0] += 2.0 * reg * theta[0]
        grads[3] += 2.0 * reg * theta[3]
        grads[2] += 2.0 * reg * 0.2 * theta[2]
        grads[5] += 2.0 * reg * 0.2 * theta[5]

        theta = [theta[i] - lr * grads[i] for i in range(7)]

        theta[0] = max(-3.0, min(3.0, theta[0]))
        theta[3] = max(-3.0, min(3.0, theta[3]))
        theta[1] = max(-4.0, min(4.0, theta[1]))
        theta[4] = max(-4.0, min(4.0, theta[4]))
        theta[2] = max(-3.0, min(3.0, theta[2]))
        theta[5] = max(-3.0, min(3.0, theta[5]))

        if step % 40 == 0:
            gap = e_data - e_neg
            history.append((step, e_data, e_neg, gap, theta[:], len(replay.items)))

    return theta, history, replay


random.seed(31)
theta0 = [0.0, -0.2, 0.0, 0.0, 0.2, 0.0, 0.0]
trained_theta, history, replay = train_ebm(theta0, real_data)

for step, e_d, e_n, gap, th, rs in history:
    print(
        f'step={step:03d}',
        f'E_data={round(e_d,4)}',
        f'E_neg={round(e_n,4)}',
        f'gap={round(gap,4)}',
        f'theta={[round(v,4) for v in th]}',
        f'replay={rs}'
    )

本物を下げ、負例を押し上げる

ここでは contrastive update を回して、real と model sample の統計がどう近づくかを見ます。

def draw_model_samples(theta, n=2600):
    init = [random.uniform(-4.5, 4.5) for _ in range(n)]
    return sgld_samples(theta, init, n_steps=60, step_size=0.05, noise_std=0.07)


random.seed(31)
model_samples = draw_model_samples(trained_theta)
real_stats = describe(real_data)
model_stats = describe(model_samples)

print('real  mean/std =', round(real_stats['mean'], 4), round(real_stats['std'], 4))
print('model mean/std =', round(model_stats['mean'], 4), round(model_stats['std'], 4))
print('real  left/right =', round(real_stats['left'], 4), round(real_stats['right'], 4))
print('model left/right =', round(model_stats['left'], 4), round(model_stats['right'], 4))
print('trained theta =', [round(v, 4) for v in trained_theta])

分配関数の近似がどれくらい荒いかを見る

Z を厳密に出せないことが、EBM の扱いづらさの核心です。ここでは近似密度がどれだけ粗いかを観察します。

def approx_partition(theta, x_min=-6.0, x_max=6.0, n_grid=5000):
    dx = (x_max - x_min) / n_grid
    s = 0.0
    for i in range(n_grid):
        x = x_min + (i + 0.5) * dx
        s += math.exp(-energy_value(x, theta))
    return s * dx


def approx_density(theta, x, z_est):
    return math.exp(-energy_value(x, theta)) / z_est


z_est = approx_partition(trained_theta)
print('approx Z =', round(z_est, 6))

probe_points = [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]
for x in probe_points:
    p = approx_density(trained_theta, x, z_est)
    print('x=', f'{x:>4.1f}', 'E=', round(energy_value(x, trained_theta), 4), 'p~', round(p, 6))

負例サンプリングが甘いと何が崩れるかを見る

SGLD のステップ不足や buffer なしの状態では、学習がどこで歪むのかを比較します。

# 学習失敗の典型例: SGLDステップ不足

def quick_train_with_sgld_steps(real_data, sgld_steps):
    random.seed(77 + sgld_steps)
    theta = [0.0, -0.1, 0.0, 0.0, 0.1, 0.0, 0.0]
    replay = SampleReplayBuffer(max_size=5000)

    for _ in range(140):
        x_data = random.sample(real_data, 64)
        x_neg0 = replay.sample(44)
        x_neg0 = x_neg0 + [random.uniform(-4.5, 4.5) for _ in range(64 - len(x_neg0))]
        x_neg = sgld_samples(theta, x_neg0, n_steps=sgld_steps, step_size=0.055, noise_std=0.085)
        replay.add(x_neg)

        _, g_data = mean_grads(x_data, theta)
        _, g_neg = mean_grads(x_neg, theta)
        grads = [g_data[i] - g_neg[i] for i in range(7)]
        theta = [theta[i] - 0.016 * grads[i] for i in range(7)]
        theta[0] = max(-3.0, min(3.0, theta[0]))
        theta[3] = max(-3.0, min(3.0, theta[3]))

    samples = draw_model_samples(theta, n=1200)
    return describe(samples)


stats_short = quick_train_with_sgld_steps(real_data, sgld_steps=5)
stats_long = quick_train_with_sgld_steps(real_data, sgld_steps=35)

print('short SGLD steps (5) :', {k: round(v, 4) for k, v in stats_short.items()})
print('long  SGLD steps (35):', {k: round(v, 4) for k, v in stats_long.items()})
print('real target          :', {k: round(v, 4) for k, v in describe(real_data).items()})

SGLDステップが少なすぎると、負例がモデル分布を十分に近似できず、エネルギー地形の更新が偏ります。実務では、SGLDステップ数、ノイズ強度、Replay比率をセットで調整します。

EBMの視点を持つと、スコアモデルや拡散モデルで出てくる「勾配で分布へ寄せる」考え方も同じ地図の上で理解しやすくなります。