制御モデルとモデルベース強化学習
制御モデルとモデルベース強化学習では、未来予測を使った意思決定を、実装可能な形に分解して学びます。
行動を入れた瞬間に、ただの予測から制御になる
世界モデルに行動 を入れると、『次に何が起きるかを見るモデル』から『次に何をさせたいかを考えるモデル』へ役割が変わります。ここで見たいのは、制御入力が状態をどう曲げ、その予測がどこまで計画に使えるかです。
モデルベース強化学習が得したいのは、実環境で高い授業料を払わずに候補行動をふるい落とすことです。その代わり、モデルの中で間違えた予測をすると悪い計画を選ぶ危険も増えます。以降では、潜在状態の遷移、観測復元、行動列の採点、誤差監視までを一つの流れとして読みます。
まず見るべきは行動の効き方
を変えたとき z_next がどう動くかを見れば、そのモデルが何を『制御できるもの』とみなしているかが分かります。行動列 plan の比較では、1 手ごとの良し悪しではなく、数ステップ先まで含めた軌跡全体を評価します。
この notebook の見どころ
制御入力が遷移に入る位置、ロールアウトが崩れたときにどんな失敗が起きるか、score_plan の差が小さいときに何を疑うべきかを順に確かめます。
ここでの採点関数は教育用の代理で、実際の MBRL では報酬・コスト・制約を別に計算して計画を選びます。それでも『モデル内で先に試す』という発想そのものはこの最小例で十分見えます。
誤差が方策に混ざる場所
制御モデルの怖さは、予測誤差がそのまま行動選択の誤りに変わることです。よさそうに見える計画が本当にいいのかは、予測値そのものより、誤差がどこで増えるかと合わせて読む必要があります。
読み方の軸
この notebook は大規模ロボティクスの実装ではなく、MBRL の中核を手で追うためのものです。『行動を入れたモデルで未来を先に採点する』という一点が掴めれば十分です。
観察 1: 制御モデルの遷移初期化
制御入力が状態へどう効くかを見るため、遷移係数を明示します。
z_t = -0.15
a_t = 0.8
A, B = 0.93, 0.16
z_next = A * z_t + B * a_t
print('task = control-model-and-mbrl')
print('z_next =', round(z_next, 6))
ここから計画候補評価へつなげます。
観察 2: 観測予測を作る
次に、潜在状態から観測を復元する写像を作ります。状態推定と観測再現の役割分担をコードで掴みます。
def decode(z):
return {'position': 2.5 * z + 0.1, 'velocity': 0.8 * z - 0.05}
obs_next = decode(z_next)
print('obs_next =', {k: round(v, 4) for k, v in obs_next.items()})
print('keys =', list(obs_next.keys()))
観測予測を別関数に切ると、遷移誤差と観測誤差を分離して調整できます。
計算の対応表
観察 3: ロールアウトを試す
ここで複数ステップ予測を実行します。1ステップでは見えない誤差累積を把握するためです。
actions = [0.0, 1.0, 1.0, 0.0, -0.5]
z = 0.1
traj = []
for a in actions:
z = 0.92 * z + 0.18 * a
traj.append(round(z, 5))
print('rollout =', traj)
長期予測で崩れるなら、遷移モデルの安定性や状態表現の情報量不足を疑います。
観察 4: 計画候補を比較する
次に、複数の行動列を比較して、どの計画が望ましいかを評価します。モデルベース強化学習の中心操作です。
plans = [[0, 1, 1], [1, 1, 1], [0, 0, 1]]
def score_plan(plan):
z = 0.1
for a in plan:
z = 0.92 * z + 0.18 * a
return z
scores = [round(score_plan(p), 5) for p in plans]
print('scores =', scores)
ここでの score_plan は、最終潜在値をそのまま比較するデモ用の代理指標です。MBRL では通常、予測した将来状態を報酬やコストへ変換して計画を選びます。
計画評価が可能になると、実環境での試行回数を抑えた探索がしやすくなります。
観察 5: モデル誤差を監視する
最後に、予測と実測の差を定量化します。世界モデルは『予測できる範囲』を常に点検する運用が重要です。
pred = [0.10, 0.22, 0.31, 0.29]
real = [0.11, 0.25, 0.28, 0.35]
errors = [abs(p - r) for p, r in zip(pred, real)]
print('errors =', [round(e, 4) for e in errors])
print('mean_error =', round(sum(errors) / len(errors), 5))
平均誤差だけでなく時点別誤差を追うと、どの遷移条件でモデルが弱いかを特定しやすくなります。
読み終えたあとに残したい視点
- MBRL の強みは、実環境で試す前に候補を落とせること。
- その代償として、モデル誤差がそのまま意思決定に混ざる。
- 制御モデルは精度だけでなく、どの行動で崩れるかまで監視して初めて使える。