最適化と正則化

深層学習では、モデルの表現力が高いだけでは十分ではありません。うまく下がるように最適化し、下がりすぎて訓練データへ貼り付きすぎないように正則化も考える必要があります。

このノートでは、同じ初期点から更新の仕方だけを変えて、学習率と weight decay が何を変えているのかを見ます。

同じ「損失が下がる」でも、更新の意味は同じではない

学習率を小さくした結果と、weight decay を入れた結果が、どちらも一見おとなしく見えることがあります。ですが前者は歩幅を縮め、後者は大きい重みを追加で押し戻しています。見た目が似ていても、更新の中身は別物です。

ここでは、見た目が似た 2 つの挙動をわざと並べます。ひとつは学習率を落として慎重に歩く更新、もうひとつは weight decay で大きい重みを嫌う更新です。両者を同じ初期点から走らせると、差はかなり露骨に出ます。

比較の主役は loss の値より、どんな軌道を通るかです

正則化の効果は 1 回の損失だけでは読み切れません。重みのノルム、更新後の位置、データ損失と total loss のズレを合わせて見てはじめて、「何を抑えているのか」が見えてきます。

ここでは toy 実験に割り切って、更新則だけを観察する

汎化性能の議論を本格的にやるには検証データや長い学習が必要です。この notebook ではそこまでは踏み込まず、まず更新式の差がその場でどう現れるかに集中します。

weight decay は「小さく進む」のではなく「大きい重みを嫌う」

この違いが腑に落ちると、学習率調整と正則化を同じ引き出しに入れなくて済みます。以後の deep learning ノートでも、loss に何が足され、更新へどう入るかを読む基礎になります。

import math

まずは同じ初期点を置く

最初に、全ての更新法が同じ場所から始まるようにします。比較の条件をそろえることで、後の差が見やすくなります。

x = [0.9, -0.3, 0.1]
w = [0.5, -0.2, 0.4]
b = 0.03
z = sum(xi * wi for xi, wi in zip(x, w)) + b
y = 1 / (1 + math.exp(-z))
print('task = optimization-regularization', 'init_pred=', round(y, 6))

同じ初期点を使うことで、手法差を比較しやすくします。

損失を分けて読む

ここではデータに対する外れを表す data_loss と、正則化項込みの total_loss を分けて見ます。正則化を足すと、この 2 つは同じ値ではなくなります。

target = 1.0
eps = 1e-9
loss = -(target * math.log(y + eps) + (1 - target) * math.log(1 - y + eps))
print('loss =', round(loss, 6))
print('error =', round(target - y, 6))

損失は『どれだけ外したか』を測る物差しです。物差しがない状態では、改善の方向を決められません。

更新式へ正則化が入ると何が変わるか

theta <- theta - eta grad に、重みを押し戻す項が入ると、単に loss を後から読み替えるだけではなく、実際の更新方向そのものが変わります。ここが weight decay の本体です。

1 ステップ更新を並べてみる

ここでは通常更新、weight decay あり、small lr を同じ条件で並べます。どれがどの方向へ、どれだけ進むかを比べると役割の違いが見えます。

lr = 0.1
grad_z = y - target  # BCE + sigmoid のとき dL/dz
w_new = [wi - lr * grad_z * xi for wi, xi in zip(w, x)]
b_new = b - lr * grad_z
z_new = sum(xi * wi for xi, wi in zip(x, w_new)) + b_new
y_new = 1 / (1 + math.exp(-z_new))
print('grad_z =', round(grad_z, 6))
print('b before/after =', round(b, 6), round(b_new, 6))
print('y before/after =', round(y, 6), round(y_new, 6))

この更新で損失が下がれば、勾配方向が合理的だったと言えます。下がらないなら学習率や符号を疑います。

ここでは重みだけでなく bias も同時に更新し、BCE と勾配の対応が崩れないようにしています。

weight decay は「小さく動く」のではなく「戻す」

ここで見るべきなのは、学習率を下げただけでは再現できない重みノルムの振る舞いです。weight decay は大きい重みほど強く押し戻すので、同じ step 幅の縮小とは別の効果を持ちます。

grad_w = [grad_z * xi for xi in x]


def predict_prob(sample, weights, bias):
    z_cur = sum(si * wi for si, wi in zip(sample, weights)) + bias
    return 1 / (1 + math.exp(-z_cur))


def bce_value(prob, target):
    return -(target * math.log(prob + eps) + (1 - target) * math.log(1 - prob + eps))


l2 = 0.1
small_lr = 0.03

before_prob = predict_prob(x, w, b)
before_loss = bce_value(before_prob, target)
before_norm = math.sqrt(sum(wi * wi for wi in w))
print('before:', 'pred=', round(before_prob, 6), 'data_loss=', round(before_loss, 6), 'weight_norm=', round(before_norm, 6))

updates = [
    ('plain', [wi - lr * gwi for wi, gwi in zip(w, grad_w)], b - lr * grad_z, 0.0),
    ('weight_decay', [wi - lr * (gwi + l2 * wi) for wi, gwi in zip(w, grad_w)], b - lr * grad_z, l2),
    ('small_lr', [wi - small_lr * gwi for wi, gwi in zip(w, grad_w)], b - small_lr * grad_z, 0.0),
]

for name, ww, bb, penalty in updates:
    prob = predict_prob(x, ww, bb)
    data_loss = bce_value(prob, target)
    weight_norm = math.sqrt(sum(wi * wi for wi in ww))
    total_loss = data_loss + 0.5 * penalty * sum(wi * wi for wi in ww)
    print(name, 'pred=', round(prob, 6), 'data_loss=', round(data_loss, 6), 'weight_norm=', round(weight_norm, 6), 'total_loss=', round(total_loss, 6))

weight decay は grad + l2 * w として更新式に入り、重みが大きい方向を追加で押し戻します。学習率を小さくするだけでは全方向の step を一様に縮めるだけなので、同じ動きにはなりません。

ミニバッチでも同じ考え方が残る

最後にミニバッチ平均勾配に weight decay を重ねます。更新単位が小さくなっても、「データから来る勾配」と「重みを押し戻す項」が別物であることは変わりません。

batch = [[0.8, -0.4, 0.2], [0.2, 0.1, -0.3], [0.5, -0.2, 0.7]]
targets = [1.0, 0.0, 1.0]
sample_preds = []
sample_losses = []
sample_grad_z = []

for bx, bt in zip(batch, targets):
    z_b = sum(xi * wi for xi, wi in zip(bx, w_new)) + b_new
    y_b = 1 / (1 + math.exp(-z_b))
    loss_b = bce_value(y_b, bt)
    sample_preds.append(y_b)
    sample_losses.append(loss_b)
    sample_grad_z.append(y_b - bt)

grad_w_batch = [
    sum(g * bx[j] for g, bx in zip(sample_grad_z, batch)) / len(batch)
    for j in range(len(w_new))
]
grad_b_batch = sum(sample_grad_z) / len(batch)

plain_batch_w = [wi - lr * gw for wi, gw in zip(w_new, grad_w_batch)]
plain_batch_b = b_new - lr * grad_b_batch
decay_batch_w = [wi - lr * (gw + l2 * wi) for wi, gw in zip(w_new, grad_w_batch)]
decay_batch_b = b_new - lr * grad_b_batch


def mean_batch_loss(weights, bias):
    vals = []
    for bx, bt in zip(batch, targets):
        z_b = sum(xi * wi for xi, wi in zip(bx, weights)) + bias
        y_b = 1 / (1 + math.exp(-z_b))
        vals.append(bce_value(y_b, bt))
    return sum(vals) / len(vals)


print('sample_preds =', [round(p, 4) for p in sample_preds])
print('sample_losses =', [round(v, 4) for v in sample_losses])
print('batch_loss_mean_before =', round(sum(sample_losses) / len(sample_losses), 6))
print('grad_b_batch =', round(grad_b_batch, 6))
print('grad_w_batch =', [round(v, 6) for v in grad_w_batch])
print('batch_loss_mean_plain_after =', round(mean_batch_loss(plain_batch_w, plain_batch_b), 6))
print('batch_loss_mean_decay_after =', round(mean_batch_loss(decay_batch_w, decay_batch_b), 6))
print('norm_plain/decay =', round(math.sqrt(sum(wi * wi for wi in plain_batch_w)), 6), round(math.sqrt(sum(wi * wi for wi in decay_batch_w)), 6))

ミニバッチでは各サンプルの損失と勾配を平均してから 1 回更新します。ここに weight decay を重ねると、平均勾配に加えて重みの大きさも同時に押し戻せるので、更新後のノルムの違いとして観察できます。

振り返り

このノートで大切なのは、最適化は下げる技術、正則化は崩れを抑える技術だと分けて読むことです。両方とも更新式へ入りますが、役割は違います。