状態表現学習
状態表現学習では、未来予測を使った意思決定を、実装可能な形に分解して学びます。
いい表現とは、きれいに縮むことではなく、あとで使えること
状態表現学習では、観測を小さく潰すだけでは不十分です。復元できても計画に使えない表現なら、世界モデルの内部状態としては弱いままです。
ここでは学習済み encoder の代わりに、手で置いた圧縮状態 z を使います。近道ではありますが、そのぶん『どの情報を残すと下流の予測や計画が助かるのか』を直接読み取れます。
圧縮で削ってよい情報、削れない情報
位置や速度のどちらかをうまく潰せても、将来予測や行動評価に必要な情報まで捨てると表現は壊れます。この notebook では、z を単なる低次元ベクトルではなく、下流タスクに渡す作業用メモリとして見ます。
読み進めるときの視点
z_next が小さくても観測復元がどこまで持つか、ロールアウトを伸ばしたときに情報不足がどこで表面化するか、scores の比較で表現の善し悪しがどう現れるかを見ます。
良い表現を復元誤差だけで決めないことが、この notebook の主題です。『あとで予測や制御に使えるか』まで含めて初めて、状態表現の価値が決まります。
この notebook の読み方
ここでは手書きの z を使っているので、学習アルゴリズムそのものよりも判定基準に集中できます。表現の良し悪しを、圧縮率ではなく下流の振る舞いで読むつもりで進めてください。
どこで効いているかを見る
観測復元がそこそこ当たっていても、計画候補の順位が安定しないなら表現は弱いままです。表現学習は『情報を残せたか』ではなく『必要な判断を支えられるか』で読むのが筋です。
この最小例で見えてくること
実際の状態表現学習では encoder/decoder に加えて再構成損失、予測損失、対比損失などが入ります。ただ、その複雑さに入る前でも、よい表現の条件はこの小さな例でかなり見えてきます。
検証 1: 表現学習の遷移初期化
状態表現学習では圧縮状態の可用性が重要なので、簡易遷移から始めます。
z_t = 0.12
a_t = 0.5
A, B = 0.91, 0.19
z_next = A * z_t + B * a_t
print('task = state-representation-learning')
print('z_next =', round(z_next, 6))
このzをどれだけ情報保持できるかが論点になります。
検証 2: 観測予測を作る
次に、潜在状態から観測を復元する写像を作ります。状態推定と観測再現の役割分担をコードで掴みます。
def decode(z):
return {'position': 2.5 * z + 0.1, 'velocity': 0.8 * z - 0.05}
obs_next = decode(z_next)
print('obs_next =', {k: round(v, 4) for k, v in obs_next.items()})
print('keys =', list(obs_next.keys()))
観測予測を別関数に切ると、遷移誤差と観測誤差を分離して調整できます。
定義の確認
検証 3: ロールアウトを試す
ここで複数ステップ予測を実行します。1ステップでは見えない誤差累積を把握するためです。
actions = [0.0, 1.0, 1.0, 0.0, -0.5]
z = 0.1
traj = []
for a in actions:
z = 0.92 * z + 0.18 * a
traj.append(round(z, 5))
print('rollout =', traj)
長期予測で崩れるなら、遷移モデルの安定性や状態表現の情報量不足を疑います。
検証 4: 計画候補を比較する
次に、複数の行動列を比較して、どの計画が望ましいかを評価します。モデルベース強化学習の中心操作です。
plans = [[0, 1, 1], [1, 1, 1], [0, 0, 1]]
def score_plan(plan):
z = 0.1
for a in plan:
z = 0.92 * z + 0.18 * a
return z
scores = [round(score_plan(p), 5) for p in plans]
print('scores =', scores)
計画評価が可能になると、実環境での試行回数を抑えた探索がしやすくなります。
検証 5: モデル誤差を監視する
最後に、予測と実測の差を定量化します。世界モデルは『予測できる範囲』を常に点検する運用が重要です。
pred = [0.10, 0.22, 0.31, 0.29]
real = [0.11, 0.25, 0.28, 0.35]
errors = [abs(p - r) for p, r in zip(pred, real)]
print('errors =', [round(e, 4) for e in errors])
print('mean_error =', round(sum(errors) / len(errors), 5))
平均誤差だけでなく時点別誤差を追うと、どの遷移条件でモデルが弱いかを特定しやすくなります。
読み終えたあとに残したい視点
- 圧縮できることと、使える表現であることは別。
- 表現の評価は下流タスクまで見ないと不十分。
- 情報を削る設計は、予測と計画の両方を見て初めて調整できる。