状態予測モデル
状態予測モデルでは、未来予測を使った意思決定を、実装可能な形に分解して学びます。
当たる予測と、使える予測は同じではない
状態予測モデルは、次に何が起こるかを先回りして見るための土台です。ただし 1 ステップ先がそこそこ当たるだけでは不十分で、ロールアウトを伸ばしたときに崩れないかまで見ないと計画には使えません。
この notebook では、未来状態の予測をそのまま異常検知にもつなげて考えます。予測が外れた瞬間は、単なる失敗ではなく『いつもと違う』を見つける手掛かりにもなるからです。
1 ステップ誤差と多ステップ誤差を分けて読む
次の一歩だけなら当たるモデルでも、その予測を自分で食べながら先へ進むとすぐ崩れることがあります。ここで見たいのは、誤差の大きさそのものより、どのくらいの長さから累積が目立ち始めるかです。
この notebook の見どころ
初期遷移の確認、観測復元との役割分担、ロールアウトの崩れ方、計画候補への影響、最後に誤差監視までを一続きで見ます。
異常検知まで視野に入れるなら、平均誤差だけでは足りません。どの時点で、どの条件で、どの誤差が急に大きくなるかを見る必要があります。
誤差が意味を持つ瞬間
予測誤差は、モデル改善の手掛かりであると同時に運用上の警報でもあります。ずれが大きくなった時点を丁寧に見ると、モデルの弱い領域と異常候補の両方が見えてきます。
この notebook の立ち位置
ここでのロールアウトは小さな教材用の例ですが、1 ステップ評価と長期評価を分ける考え方は実務でもそのまま重要です。
実験 1: 状態予測モデルの初期遷移
未来状態予測の誤差を見るため、遷移初期値を定義します。
z_t = -0.05
a_t = 1.1
A, B = 0.89, 0.21
z_next = A * z_t + B * a_t
print('task = state-prediction-models')
print('z_next =', round(z_next, 6))
この遷移誤差を後続で時系列的に評価します。
実験 2: 観測予測を作る
次に、潜在状態から観測を復元する写像を作ります。状態推定と観測再現の役割分担をコードで掴みます。
def decode(z):
return {'position': 2.5 * z + 0.1, 'velocity': 0.8 * z - 0.05}
obs_next = decode(z_next)
print('obs_next =', {k: round(v, 4) for k, v in obs_next.items()})
print('keys =', list(obs_next.keys()))
観測予測を別関数に切ると、遷移誤差と観測誤差を分離して調整できます。
式と実装の往復
実験 3: ロールアウトを試す
ここで複数ステップ予測を実行します。1ステップでは見えない誤差累積を把握するためです。
actions = [0.0, 1.0, 1.0, 0.0, -0.5]
z = 0.1
traj = []
for a in actions:
z = 0.92 * z + 0.18 * a
traj.append(round(z, 5))
print('rollout =', traj)
長期予測で崩れるなら、遷移モデルの安定性や状態表現の情報量不足を疑います。
実験 4: 計画候補を比較する
次に、複数の行動列を比較して、どの計画が望ましいかを評価します。モデルベース強化学習の中心操作です。
plans = [[0, 1, 1], [1, 1, 1], [0, 0, 1]]
def score_plan(plan):
z = 0.1
for a in plan:
z = 0.92 * z + 0.18 * a
return z
scores = [round(score_plan(p), 5) for p in plans]
print('scores =', scores)
計画評価が可能になると、実環境での試行回数を抑えた探索がしやすくなります。
実験 5: モデル誤差を監視する
最後に、予測と実測の差を定量化します。世界モデルは『予測できる範囲』を常に点検する運用が重要です。
pred = [0.10, 0.22, 0.31, 0.29]
real = [0.11, 0.25, 0.28, 0.35]
errors = [abs(p - r) for p, r in zip(pred, real)]
print('errors =', [round(e, 4) for e in errors])
print('mean_error =', round(sum(errors) / len(errors), 5))
平均誤差だけでなく時点別誤差を追うと、どの遷移条件でモデルが弱いかを特定しやすくなります。異常検知に使うなら、たとえば mean_error ではなく各時点誤差の 95 パーセンタイルや固定閾値を超えた時点を異常候補とみなします。
読み終えたあとに残したい視点
- 1 ステップ先の正確さだけではモデルの価値は決まらない。
- 長期ロールアウトは誤差累積の見取り図になる。
- 予測誤差は、異常検知の入口としても読む価値がある。