GAN

GAN(Generative Adversarial Network)は、生成器 G と識別器 D を競わせることでデータ分布を学習する手法です。ここでは1次元データを使い、数式とコードを行き来しながら、学習がうまくいく条件と失敗する条件を確認します。

GAN は、生成器を直接採点できないから識別器を立てる

生成器だけを見ても「どれだけ本物に近いか」を測りづらいので、対戦相手として識別器を置きます。GAN の本質は、この採点器と生成器を同時に更新することで、分布の差を間接的に埋める点にあります。

このノートでは、まず 2 峰の 1 次元分布を相手にして GAN の対戦構造を見ます。そのあとで minimax と non-saturating の差、表現力不足による mode collapse、LSGAN や WGAN が何を直したいのかへ進みます。

読み筋は「対戦の構造」と「失敗の構造」の 2 本です

損失関数だけ追っても、GAN は分かりにくいモデルです。生成器がそもそも分布を表せるか、識別器が強すぎないか、勾配が枯れていないか。この notebook では失敗の出方を先に意識して読みます。

最初に 2 峰分布を使うのは、mode collapse を見つけやすいからです

単峰データだと生成器が一部しか出していない失敗に気づきにくくなります。左右 2 モードを用意すると、多様性が落ちたときの異常がかなり分かりやすくなります。

non-saturating が必要になるのは、初期の勾配が弱すぎるから

GAN の古典的な難所は、識別器がすぐ勝ちすぎて生成器が学びにくくなることです。後半では損失の違いを、勾配の強さとして見ます。

WGAN や LSGAN は「別損失」ではなく、失敗モードへの対策として読む

単に名前を覚えるより、どの不安定さを減らしたくて設計されたのかを見た方が理解しやすくなります。

ここでは GAN の不安定さを、できるだけ小さく見える形にしている

深いネットワークのテクニックを盛る前に、対戦構造そのものが何を良くし、どこで壊れやすいのかを読むことを優先しています。

まずは対戦前の分布を置く

最初の節では、実データ分布と初期生成分布がどれくらい離れているかを確認します。対戦の出発点です。

import math
import random
import statistics

実データ分布を pdata(x)p_data(x)、生成分布を pg(x)p_g(x) と書くと、目標は p_g ≈ p_data です。GANでは D が「本物か偽物か」を見分ける能力を上げ、G がその D をだます能力を上げます。

この設計の利点は、尤度を直接書きにくい問題でも、識別問題として学習を回せることです。一方で、学習は不安定になりやすく、目的関数とモデル設計の意図を理解していないと改善が難しくなります。

まず、2つの山を持つ1次元分布を実データとして作ります。2峰性のデータを使う理由は、mode collapse(片方の山しか出せなくなる失敗)を検出しやすくするためです。

random.seed(21)

def sample_real(n: int):
    out = []
    for _ in range(n):
        if random.random() < 0.5:
            out.append(random.gauss(-2.0, 0.35))
        else:
            out.append(random.gauss(2.0, 0.35))
    return out

real_preview = sample_real(2000)
print('real mean =', round(statistics.mean(real_preview), 4))
print('real stdev =', round(statistics.pstdev(real_preview), 4))
print('left ratio =', round(sum(1 for x in real_preview if x < 0) / len(real_preview), 4))
print('right ratio=', round(sum(1 for x in real_preview if x >= 0) / len(real_preview), 4))

GANの代表的な目的関数は次です。

maxD  Expdata[logD(x)]+Ezp(z)[log(1D(G(z)))]\max_D\; \mathbb{E}_{x\sim p_{data}}[\log D(x)] + \mathbb{E}_{z\sim p(z)}[\log(1-D(G(z)))] maxG  Ezp(z)[logD(G(z))](non-saturating 版)\max_G\; \mathbb{E}_{z\sim p(z)}[\log D(G(z))]\quad\text{(non-saturating 版)}

教科書で出る minGE[log(1D(G(z)))]\min_G \mathbb{E}[\log(1-D(G(z)))](minimax 版)は、初期段階で勾配が弱くなりやすいので、実装では non-saturating 版がよく使われます。

def sigmoid(t: float) -> float:
    if t >= 0:
        e = math.exp(-t)
        return 1.0 / (1.0 + e)
    e = math.exp(t)
    return e / (1.0 + e)


def clamp_prob(p: float, eps: float = 1e-8) -> float:
    return min(1.0 - eps, max(eps, p))


def sample_z(n: int):
    return [random.gauss(0.0, 1.0) for _ in range(n)]


# 線形の最小モデル
# G(z) = a z + b
# D(x) = sigmoid(c x + d)
def generator_linear(z: float, theta):
    a, b = theta
    return a * z + b


def discriminator(x: float, phi):
    c, d = phi
    return sigmoid(c * x + d)


def losses_on_batch(theta, phi, x_real, z_batch, generator_fn):
    x_fake = [generator_fn(z, theta) for z in z_batch]

    d_real = [clamp_prob(discriminator(x, phi)) for x in x_real]
    d_fake = [clamp_prob(discriminator(x, phi)) for x in x_fake]

    ld = sum(math.log(p) for p in d_real) / len(d_real) + sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    lg_nonsat = sum(math.log(p) for p in d_fake) / len(d_fake)          # maximize
    lg_minimax = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)   # minimize

    return ld, lg_nonsat, lg_minimax, x_fake

生成器と識別器を交互に更新する

ここでは実際に学習を回し、損失だけでなく Wasserstein 的な距離や左右モード比率も一緒に見ます。

theta0 = [0.15, 0.0]
phi0 = [0.25, 0.0]

x_real = sample_real(256)
z_batch = sample_z(256)
ld0, lg0, lg_min0, x_fake0 = losses_on_batch(theta0, phi0, x_real, z_batch, generator_linear)

print('initial L_D (maximize)        =', round(ld0, 5))
print('initial L_G non-saturating    =', round(lg0, 5))
print('initial L_G minimax objective =', round(lg_min0, 5))
print('fake preview mean/stdev =', round(statistics.mean(x_fake0), 4), round(statistics.pstdev(x_fake0), 4))

次に学習を回します。今回は自動微分ではなく有限差分を使います。深層学習の本番では自動微分を使うべきですが、有限差分だと「どの目的を最大化・最小化しているか」が見えやすく、概念理解には有効です。

このノートでは LDL_DLGL_G をどちらも「最大化」する書き方に統一しているため、更新式は params = params + lr * grad になります。一般的な最小化実装(params - lr * grad)と符号が逆に見えるのは、最適化している向きが違うためです。

ログに出す W1 は実分布と生成分布の距離なので、小さいほど良い値です。leftright は左右モードの比率で、両者が 0.5 付近なら mode collapse が起きにくい状態と解釈できます。

def finite_diff_grad(fn, params, h: float = 1e-4):
    grads = []
    for i in range(len(params)):
        plus = params[:]
        minus = params[:]
        plus[i] += h
        minus[i] -= h
        grads.append((fn(plus) - fn(minus)) / (2.0 * h))
    return grads
def empirical_w1_1d(xs, ys):
    n = min(len(xs), len(ys))
    xs_sorted = sorted(xs)[:n]
    ys_sorted = sorted(ys)[:n]
    return sum(abs(a - b) for a, b in zip(xs_sorted, ys_sorted)) / n
def evaluate_distribution(theta, generator_fn, n=4000):
    real = sample_real(n)
    fake = [generator_fn(z, theta) for z in sample_z(n)]
    left = sum(1 for x in fake if x < 0) / len(fake)
    right = 1.0 - left
    return {
        'real_mean': statistics.mean(real),
        'real_std': statistics.pstdev(real),
        'fake_mean': statistics.mean(fake),
        'fake_std': statistics.pstdev(fake),
        'w1': empirical_w1_1d(real, fake),
        'left': left,
        'right': right,
    }
def train_gan(theta_init, phi_init, generator_fn, steps=320, batch_size=128, lr_g=0.04, lr_d=0.04, d_updates=2, log_every=40):
    theta = theta_init[:]
    phi = phi_init[:]
    history = []
    for step in range(steps + 1):
        for _ in range(d_updates):
            x_real = sample_real(batch_size)
            z = sample_z(batch_size)
            def d_objective(phi_try):
                ld, _, _, _ = losses_on_batch(theta, phi_try, x_real, z, generator_fn)
                return ld
            g_phi = finite_diff_grad(d_objective, phi)
            # L_D を最大化するので、勾配上昇(+)で更新
            phi = [p + lr_d * gp for p, gp in zip(phi, g_phi)]
        x_real = sample_real(batch_size)
        z = sample_z(batch_size)
        def g_objective(theta_try):
            _, lg, _, _ = losses_on_batch(theta_try, phi, x_real, z, generator_fn)
            return lg
        g_theta = finite_diff_grad(g_objective, theta)
        # L_G(non-saturating)を最大化するので、勾配上昇(+)で更新
        theta = [t + lr_g * gt for t, gt in zip(theta, g_theta)]
        if step % log_every == 0:
            x_eval = sample_real(1000)
            z_eval = sample_z(1000)
            ld_eval, lg_eval, _, _ = losses_on_batch(theta, phi, x_eval, z_eval, generator_fn)
            stats = evaluate_distribution(theta, generator_fn, n=2000)
            history.append((step, theta[:], phi[:], ld_eval, lg_eval, stats['w1'], stats['left'], stats['right']))
    return theta, phi, history

表現力不足だけでも collapse は起きる

学習が悪いから崩れるだけではありません。線形生成器のように、そもそも 2 峰を表せない表現では collapse は避けられません。

random.seed(21)
lin_theta, lin_phi, lin_history = train_gan(theta0, phi0, generator_linear)

for step, th, ph, ld, lg, w1, left, right in lin_history:
    print(
        f'step={step:03d}',
        f'theta={[round(v,4) for v in th]}',
        f'phi={[round(v,4) for v in ph]}',
        f'L_D={round(ld,4)}',
        f'L_G={round(lg,4)}',
        f'W1={round(w1,4)}',
        f'left={round(left,3)}',
        f'right={round(right,3)}'
    )

lin_stats = evaluate_distribution(lin_theta, generator_linear)
print()
print('linear generator final stats:')
for k, v in lin_stats.items():
    print(k, '=', round(v, 4))

ここで重要なのは、学習アルゴリズム以前に表現力の限界があることです。z~N(0,1) に対して線形写像 a z + b を使うと、pgp_g は必ず単峰ガウスになります。したがって、2峰の実分布を正確には表現できません。

つまり「学習が遅い」だけでなく、「モデルがその分布族を持っていない」ことが失敗の根本原因です。

# 2分岐生成器: z<0 と z>=0 で別の線形写像を使う
# G(z) = a_l z + b_l  (z<0),  a_r z + b_r (z>=0)
def generator_piecewise(z: float, theta):
    a_l, b_l, a_r, b_r = theta
    if z < 0.0:
        return a_l * z + b_l
    return a_r * z + b_r


random.seed(21)
# 2峰を作りやすい初期値を置く(実務ではここも設計対象)
theta_pw0 = [0.35, -1.8, 0.35, 1.8]
phi_pw0 = [0.25, 0.0]

pw_theta, pw_phi, pw_history = train_gan(
    theta_pw0,
    phi_pw0,
    generator_piecewise,
    steps=400,
    lr_g=0.02,
    lr_d=0.03,
    d_updates=2,
)

for step, th, ph, ld, lg, w1, left, right in pw_history:
    print(
        f'step={step:03d}',
        f'theta={[round(v,4) for v in th]}',
        f'phi={[round(v,4) for v in ph]}',
        f'L_D={round(ld,4)}',
        f'L_G={round(lg,4)}',
        f'W1={round(w1,4)}',
        f'left={round(left,3)}',
        f'right={round(right,3)}'
    )

pw_stats = evaluate_distribution(pw_theta, generator_piecewise)
print()
print('piecewise generator final stats:')
for k, v in pw_stats.items():
    print(k, '=', round(v, 4))

損失の形が勾配をどう変えるかを見る

minimax と non-saturating の差を、識別器ロジットに対する勾配として見ます。ここが GAN の更新感覚をつかむ要所です。

print('comparison (linear vs piecewise):')
print('W1            =', round(lin_stats['w1'], 4), 'vs', round(pw_stats['w1'], 4))
print('fake stdev    =', round(lin_stats['fake_std'], 4), 'vs', round(pw_stats['fake_std'], 4))
print('mode balance  =', round(min(lin_stats['left'], lin_stats['right']) / max(lin_stats['left'], lin_stats['right']), 4),
      'vs',
      round(min(pw_stats['left'], pw_stats['right']) / max(pw_stats['left'], pw_stats['right']), 4))

non-saturating 目的を使う理由も確認します。s を識別器ロジット(D=σ(s))とすると、

です。D(G(z)) が小さい初期段階では、minimax の勾配は小さくなりやすく、non-saturating のほうが更新信号を確保しやすくなります。

def d_minimax_ds(p):
    return -p


def d_nonsat_ds(p):
    return -(1.0 - p)

for p in [0.001, 0.01, 0.05, 0.1, 0.5, 0.9]:
    print(
        f'D(fake)={p:>5}',
        f'|d(minimax)/ds|={abs(d_minimax_ds(p)):.4f}',
        f'|d(non-sat)/ds|={abs(d_nonsat_ds(p)):.4f}'
    )

GAN派生の代表例として LSGAN と WGAN の方向性を整理します。LSGAN は二乗誤差で勾配をなめらかにし、WGAN は Wasserstein 距離に基づく評価で学習安定化を狙います。どちらも「単に新しい損失」ではなく、失敗モードに対応した設計です。

次のセルでは次を観察してください。

def lsgan_losses(theta, phi, x_real, z_batch, generator_fn):
    x_fake = [generator_fn(z, theta) for z in z_batch]
    d_real = [discriminator(x, phi) for x in x_real]
    d_fake = [discriminator(x, phi) for x in x_fake]

    # 典型的な最小化形式
    loss_d = 0.5 * (
        sum((p - 1.0) ** 2 for p in d_real) / len(d_real)
        + sum((p - 0.0) ** 2 for p in d_fake) / len(d_fake)
    )
    loss_g = 0.5 * sum((p - 1.0) ** 2 for p in d_fake) / len(d_fake)
    return loss_d, loss_g


x_ref = sample_real(256)
z_ref = sample_z(256)
ld_gan, lg_gan, _, _ = losses_on_batch(pw_theta, pw_phi, x_ref, z_ref, generator_piecewise)
ld_ls, lg_ls = lsgan_losses(pw_theta, pw_phi, x_ref, z_ref, generator_piecewise)

print('GAN objective (maximize L_D, L_G):', round(ld_gan, 5), round(lg_gan, 5))
print('LSGAN loss (minimize loss_D, loss_G):', round(ld_ls, 5), round(lg_ls, 5))

critic_weight = 1.7
clip_value = 0.1
critic_weight_clipped = max(-clip_value, min(clip_value, critic_weight))
print('WGAN intuition: raw critic weight =', critic_weight, '-> clipped =', critic_weight_clipped)

GANで精度を上げるときは、損失関数だけでなく、生成器の分布表現力・識別器の強さ・学習率バランスを同時に設計する必要があります。このノートで見た通り、同じGANでも「分布をそもそも表現できるかどうか」で結果は大きく変わります。