深層強化学習
表形式の Q 学習は、状態数が少ないうちは明快ですが、観測が大きくなるとすぐ持ちきれなくなります。深層強化学習は、その表をニューラルネットへ置き換えて一般化を持ち込む発想から始まります。
参考動画(外部)
授業本編ではなく、別の説明で見直したいときの参考材料です。
何が「深層」に置き換わるのか
変わるのはベルマン更新そのものではなく、Q(s,a) の持ち方です。表に値を書く代わりに、状態特徴を入力したネットワークが各行動の Q 値を出すようになります。その結果、一つの更新が近い状態へ波及する一般化が起きます。
この notebook は DQN 系の最小骨格だけに絞っています。target network や replay buffer はまだ入れず、まず をネットワーク近似に乗せると何が起こるかを見る構成です。
この notebook の見どころ
one-hot 状態をネットへ入れて Q を出す流れ、終端遷移で bootstrap を切る条件、更新が他状態へ波及する一般化の入口を見ます。
ここで扱うのは深層強化学習の完成形ではなく、表形式 Q 学習から neural approximation へ渡る橋です。深く見るべき差分は、式よりも表現の持ち方にあります。
読み方の軸
表をネットに置き換えても、target の中身は変わりません。何が残り、何が新しく難しくなるのかを切り分けて読むと、DQN 系の構成要素が整理しやすくなります。
ネットワークと更新則を用意する
まずは表形式の Q テーブルの代わりになる最小 MLP と、TD target でそれを更新するための補助関数を定義します。
import numpy as np
np.random.seed(0)
gamma = 0.95
lr = 0.03
# transition = (state, action, reward, next_state, done)
transitions = [
(0, 1, 1.0, 1, False),
(0, 0, 0.2, 0, False),
(1, 0, 0.5, 0, False),
(1, 1, 1.2, 1, True),
]
W1 = 0.3 * np.random.randn(2, 8)
b1 = np.zeros(8)
W2 = 0.3 * np.random.randn(8, 2)
b2 = np.zeros(2)
def one_hot(s):
x = np.zeros(2)
x[s] = 1.0
return x
def forward(x):
h_pre = x @ W1 + b1
h = np.tanh(h_pre)
q = h @ W2 + b2
return h_pre, h, q
終端遷移を含む最小学習ループを見る
次に、通常遷移と終端遷移を混ぜた小さな学習ループを回し、target と err がどこで変わるかを追います。
terminal_updates = 0
for step in range(1200):
s, a, r, s_next, done = transitions[np.random.randint(len(transitions))]
x = one_hot(s)
xn = one_hot(s_next)
h_pre, h, q = forward(x)
_, _, q_next = forward(xn)
if done:
terminal_updates += 1
target = r
else:
target = r + gamma * np.max(q_next)
err = q[a] - target
# backprop for selected action loss 0.5*(q[a]-target)^2
dq = np.zeros_like(q)
dq[a] = err
dW2 = np.outer(h, dq)
db2 = dq
dh = W2 @ dq
dh_pre = dh * (1.0 - np.tanh(h_pre) ** 2)
dW1 = np.outer(x, dh_pre)
db1 = dh_pre
W2 -= lr * dW2
b2 -= lr * db2
W1 -= lr * dW1
b1 -= lr * db1
if step % 300 == 0:
q0 = forward(one_hot(0))[2]
q1 = forward(one_hot(1))[2]
print(f'step={step}: Q(s0)={np.round(q0,4)}, Q(s1)={np.round(q1,4)}, terminal_updates={terminal_updates}')
ネットワーク近似では、更新が他状態へ波及する一般化効果が得られます。この利点と不安定さが同時に出るので、実際の DQN では target network や replay buffer が必要になります。