TD(λ)

TD(λ)\lambda) は、1-step TD と Monte Carlo を別々の手法として扱わず、連続的につなぐ見方です。λ\lambda を動かすだけで、短い目標値と長い目標値の混合比を滑らかに変えられます。

参考動画(外部)

授業本編ではなく、別の説明で見直したいときの参考材料です。

λ-return は「いろいろな n-step return の平均」

TD(λ)\lambda) を難しく見せるのは記号の多さですが、発想はかなり単純です。1-step, 2-step, 3-step ... の目標値を用意して、それらを λ\lambda に応じて重み付きで混ぜています。

前向き視点では GtλG_t^\lambda をそのまま作れますが、毎回未来を全部見直すのは重いので、後ろ向き視点では eligibility trace で近い効果をオンラインに実装します。ここではその対応関係を読むのが目的です。

見るべきポイント

λ=0\lambda=0 から λo1\lambda o 1 で target がどう変わるか、trace が何を覚えているか、forward view と backward view がどうつながるかを見ます。

この notebook は λ\lambda の意味を掴むためのものです。trace の設計差そのものは次の eligibility trace notebook でさらに細かく見ます。

読み方の軸

λ\lambda は魔法の安定化パラメータではなく、どの長さの target をどれだけ混ぜるかを決める重みです。1-step と Monte Carlo を一つの軸に置き直すつもりで読んでください。

import math

gamma = 0.9
rewards = [0.4, -0.2, 0.8, 1.1]
v_boot = 0.5


def n_step_return(rews, n, gamma, v_boot):
    n = min(n, len(rews))
    g = 0.0
    for k in range(n):
        g += (gamma ** k) * rews[k]
    if n < len(rews):
        g += (gamma ** n) * v_boot
    return g


def lambda_return(rews, gamma, lam, v_boot):
    T = len(rews)
    out = 0.0
    for n in range(1, T + 1):
        w = (1 - lam) * (lam ** (n - 1)) if n < T else (lam ** (n - 1))
        out += w * n_step_return(rews, n, gamma, v_boot)
    return out

for lam in [0.0, 0.3, 0.7, 0.95]:
    print('lambda=', lam, 'G^lambda=', round(lambda_return(rewards, gamma, lam, v_boot), 6))

後ろ向き視点へ落とし込む

et(s)=γλet1(s)+1[st=s],V(s)V(s)+αδtet(s)e_t(s)=\gamma\lambda e_{t-1}(s)+\mathbf{1}[s_t=s],\quad V(s)\leftarrow V(s)+\alpha\,\delta_t\,e_t(s)

forward view を毎回作り直さず、現在の TD 誤差を過去へ配って近い効果を得るのが backward view です。

alpha = 0.15
lam = 0.7
states = ['s0', 's1', 's0', 's2']
rewards = [0.3, 0.1, 0.7, 0.0]
V = {'s0': 0.2, 's1': 0.4, 's2': 0.1}
E = {'s0': 0.0, 's1': 0.0, 's2': 0.0}

for t in range(len(states) - 1):
    s = states[t]
    s_next = states[t + 1]
    r = rewards[t]
    delta = r + gamma * V[s_next] - V[s]

    for k in E:
        E[k] *= gamma * lam
    E[s] += 1.0

    for k in V:
        V[k] += alpha * delta * E[k]

print('Updated V =', {k: round(v, 6) for k, v in V.items()})
print('Final trace =', {k: round(v, 6) for k, v in E.items()})