損失関数と勾配降下法
学習が進むとは、何となく精度が上がることではなく、「損失という物差しが小さくなる方向へパラメータが動くこと」です。
このノートでは、損失が何を測っているかを確認し、その勾配が更新式へどう入るか、学習率が大きいと何が起こるか、ミニバッチ平均は何をしているのかを、最小の計算で追います。
学習とは、結局どの数が下がっていく現象なのか
精度が上がるという言い方だけでは、更新式の意味がぼやけます。実際にパラメータを動かしているのは損失であり、勾配降下法はその数値を下げる向きへ一歩ずつ進む手続きです。
最初にやることは単純で、適当な初期値から出した予測がどれだけ外れているかを測ることです。その数値を 1 回の更新で本当に下げられるかを見ながら、学習率とミニバッチが更新の質感をどう変えるかへ進みます。
まず見るべきは、式よりも更新前後の差です
BCE や勾配の式を眺めるだけでは、学習はまだ実感しにくいものです。この notebook では、更新前の損失と更新後の損失を何度も並べて、式が本当にモデルを改善の方向へ押しているかを確認します。
学習率を変えると、同じ勾配でも振る舞いが変わる
勾配は方向を与えますが、どれだけ進むかは学習率が決めます。方向が正しくても歩幅が大きすぎれば飛び越え、小さすぎれば進みが見えません。ここではその感覚を最小の例で押さえます。
ミニバッチは妥協ではなく、更新を整える仕組みでもある
1 件ずつ見る SGD と、全件まとめる完全バッチの中間にあるのがミニバッチです。計算都合だけでなく、更新のばらつきと安定性を調整する設計として読んでください。
import math
まずは初期予測を固定する
最初に、損失を計算できる初期状態を作ります。ここではまだ学習しません。何も学習していない予測がどの程度ずれているかを基準にします。
x = [1.0, -0.5, 0.3]
w = [0.2, -0.4, 0.6]
b = 0.12
z = sum(xi * wi for xi, wi in zip(x, w)) + b
y = 1 / (1 + math.exp(-z))
print('task = loss-and-gradient-descent', 'pred=', round(y, 6))
この初期予測を使って、損失の勾配と更新方向を確認します。
損失は、予測の外れを数える物差し
予測が正解に近いほど損失は小さく、外れるほど大きくなります。ここで物差しをはっきりさせると、次の更新式が急に具体的になります。
target = 1.0
eps = 1e-9
loss = -(target * math.log(y + eps) + (1 - target) * math.log(1 - y + eps))
print('loss =', round(loss, 6))
print('error =', round(target - y, 6))
損失は『どれだけ外したか』を測る物差しです。物差しがない状態では、改善の方向を決められません。
式と更新を対応づける
z = Wx + b で予測の元を作り、損失を計算し、theta <- theta - eta grad で更新する。この 3 行が学習の最小単位です。
1 回だけ更新して、本当に下がるかを見る
ここでは勾配降下法を 1 ステップだけ回します。更新後に損失が下がっていれば、少なくとも方向としては合理的だったと言えます。
lr = 0.1
grad_z = y - target # BCE + sigmoid のとき dL/dz
w_new = [wi - lr * grad_z * xi for wi, xi in zip(w, x)]
b_new = b - lr * grad_z
z_new = sum(xi * wi for xi, wi in zip(x, w_new)) + b_new
y_new = 1 / (1 + math.exp(-z_new))
print('grad_z =', round(grad_z, 6))
print('b before/after =', round(b, 6), round(b_new, 6))
print('y before/after =', round(y, 6), round(y_new, 6))
この更新で損失が下がれば、勾配方向が合理的だったと言えます。下がらないなら学習率や符号を疑います。
ここでは重みだけでなく bias も同時に更新し、BCE と勾配の対応が崩れないようにしています。
正則化を足すと、何を抑えたいのか
次は損失へ penalty を足して、重みが大きくなりすぎるのを抑える感覚を見ます。ここではまだ「正則化があると物差しがどう変わるか」を感じる段階です。
l2 = 0.01
weight_sq = sum(wi * wi for wi in w_new)
y_reg = 1 / (1 + math.exp(-(sum(xi * wi for xi, wi in zip(x, w_new)) + b_new)))
data_loss_reg = -(target * math.log(y_reg + eps) + (1 - target) * math.log(1 - y_reg + eps))
loss_reg = data_loss_reg + l2 * weight_sq
print('data_loss@updated =', round(data_loss_reg, 6))
print('weight_sq =', round(weight_sq, 6))
print('regularized loss =', round(loss_reg, 6))
ここでは損失に penalty が足されることだけ確認しています。正則化が更新式そのものをどう変えるかは、後続の optimization-regularization ノートで詳しく扱います。
ミニバッチ平均は何をしているのか
最後に、サンプルごとの損失と勾配を平均してから更新します。1 件ずつの更新よりぶれを減らしつつ、全件更新より計算を小分けにできる、その中間がミニバッチです。
batch = [[0.8, -0.4, 0.2], [0.2, 0.1, -0.3], [0.5, -0.2, 0.7]]
targets = [1.0, 0.0, 1.0]
eps = 1e-7
sample_preds = []
sample_losses = []
sample_grad_z = []
for bx, bt in zip(batch, targets):
z_b = sum(xi * wi for xi, wi in zip(bx, w_new)) + b_new
y_b = 1 / (1 + math.exp(-z_b))
loss_b = -(bt * math.log(y_b + eps) + (1 - bt) * math.log(1 - y_b + eps))
sample_preds.append(y_b)
sample_losses.append(loss_b)
sample_grad_z.append(y_b - bt)
grad_w_batch = [
sum(g * bx[j] for g, bx in zip(sample_grad_z, batch)) / len(batch)
for j in range(len(w_new))
]
grad_b_batch = sum(sample_grad_z) / len(batch)
w_batch = [wi - lr * gw for wi, gw in zip(w_new, grad_w_batch)]
b_batch = b_new - lr * grad_b_batch
updated_losses = []
for bx, bt in zip(batch, targets):
z_b = sum(xi * wi for xi, wi in zip(bx, w_batch)) + b_batch
y_b = 1 / (1 + math.exp(-z_b))
updated_losses.append(-(bt * math.log(y_b + eps) + (1 - bt) * math.log(1 - y_b + eps)))
print('sample_preds =', [round(p, 4) for p in sample_preds])
print('sample_losses =', [round(v, 4) for v in sample_losses])
print('batch_loss_mean =', round(sum(sample_losses) / len(sample_losses), 6))
print('grad_b_batch =', round(grad_b_batch, 6))
print('grad_w_batch =', [round(v, 6) for v in grad_w_batch])
print('batch_loss_mean_after =', round(sum(updated_losses) / len(updated_losses), 6))
ミニバッチでは各サンプルの損失と勾配を平均してから 1 回更新します。1 件ずつの SGD よりぶれが減り、完全バッチより計算を小分けにできるので、その中間としてよく使われます。
要点整理
このノートで持ち帰るべきなのは、損失が物差しで、勾配が下げる方向で、学習率が歩幅だという見方です。勾配降下法は式よりも、この役割分担で覚える方が実戦的です。