深層強化学習

表形式の Q 学習は、状態数が少ないうちは明快ですが、観測が大きくなるとすぐ持ちきれなくなります。深層強化学習は、その表をニューラルネットへ置き換えて一般化を持ち込む発想から始まります。

参考動画(外部)

授業本編ではなく、別の説明で見直したいときの参考材料です。

何が「深層」に置き換わるのか

変わるのはベルマン更新そのものではなく、Q(s,a) の持ち方です。表に値を書く代わりに、状態特徴を入力したネットワークが各行動の Q 値を出すようになります。その結果、一つの更新が近い状態へ波及する一般化が起きます。

この notebook は DQN 系の最小骨格だけに絞っています。target network や replay buffer はまだ入れず、まず target=r+γmaxqnexttarget = r + \gamma \max q_next をネットワーク近似に乗せると何が起こるかを見る構成です。

この notebook の見どころ

one-hot 状態をネットへ入れて Q を出す流れ、終端遷移で bootstrap を切る条件、更新が他状態へ波及する一般化の入口を見ます。

ここで扱うのは深層強化学習の完成形ではなく、表形式 Q 学習から neural approximation へ渡る橋です。深く見るべき差分は、式よりも表現の持ち方にあります。

読み方の軸

表をネットに置き換えても、target の中身は変わりません。何が残り、何が新しく難しくなるのかを切り分けて読むと、DQN 系の構成要素が整理しやすくなります。

ネットワークと更新則を用意する

まずは表形式の Q テーブルの代わりになる最小 MLP と、TD target でそれを更新するための補助関数を定義します。

import numpy as np
np.random.seed(0)

gamma = 0.95
lr = 0.03

# transition = (state, action, reward, next_state, done)
transitions = [
    (0, 1, 1.0, 1, False),
    (0, 0, 0.2, 0, False),
    (1, 0, 0.5, 0, False),
    (1, 1, 1.2, 1, True),
]

W1 = 0.3 * np.random.randn(2, 8)
b1 = np.zeros(8)
W2 = 0.3 * np.random.randn(8, 2)
b2 = np.zeros(2)


def one_hot(s):
    x = np.zeros(2)
    x[s] = 1.0
    return x


def forward(x):
    h_pre = x @ W1 + b1
    h = np.tanh(h_pre)
    q = h @ W2 + b2
    return h_pre, h, q

終端遷移を含む最小学習ループを見る

次に、通常遷移と終端遷移を混ぜた小さな学習ループを回し、targeterr がどこで変わるかを追います。

terminal_updates = 0

for step in range(1200):
    s, a, r, s_next, done = transitions[np.random.randint(len(transitions))]
    x = one_hot(s)
    xn = one_hot(s_next)

    h_pre, h, q = forward(x)
    _, _, q_next = forward(xn)

    if done:
        terminal_updates += 1
        target = r
    else:
        target = r + gamma * np.max(q_next)
    err = q[a] - target

    # backprop for selected action loss 0.5*(q[a]-target)^2
    dq = np.zeros_like(q)
    dq[a] = err

    dW2 = np.outer(h, dq)
    db2 = dq
    dh = W2 @ dq
    dh_pre = dh * (1.0 - np.tanh(h_pre) ** 2)
    dW1 = np.outer(x, dh_pre)
    db1 = dh_pre

    W2 -= lr * dW2
    b2 -= lr * db2
    W1 -= lr * dW1
    b1 -= lr * db1

    if step % 300 == 0:
        q0 = forward(one_hot(0))[2]
        q1 = forward(one_hot(1))[2]
        print(f'step={step}: Q(s0)={np.round(q0,4)}, Q(s1)={np.round(q1,4)}, terminal_updates={terminal_updates}')

ネットワーク近似では、更新が他状態へ波及する一般化効果が得られます。この利点と不安定さが同時に出るので、実際の DQN では target network や replay buffer が必要になります。