状態空間モデル
状態空間モデルでは、未来予測を使った意思決定を、実装可能な形に分解して学びます。
観測だけでは足りないとき、状態を仮定する
状態空間モデルの発想は単純です。見えている値がノイジーだったり欠けていたりするなら、その裏にもっと滑らかに進む『状態』を置いてしまおう、という考え方です。
この notebook では position と velocity を観測側に出しつつ、内部では潜在状態 z を遷移させます。大事なのは数式を暗記することではなく、『状態遷移』と『観測生成』を分けると何が読みやすくなるかを掴むことです。
状態と観測を分けると何が助かるか
観測は外から見える量、状態はその背後で連続的に動いている内部の要約です。観測がぶれても、状態がなめらかに追えていれば長い予測は安定しやすくなります。ここでは線形の最小例で、その分業をそのまま見ます。
見るべきポイント
z_next が遷移式どおり動くか、decode が状態をどう観測へ戻すか、ロールアウトを伸ばしたときにどこから誤差が増えるかを追います。
線形 SSM は入口にすぎませんが、状態と観測を分ける見方は非線形モデルやカルマンフィルタ系にもそのまま残ります。まずはこの小さい系で『何を状態と呼ぶのか』をはっきりさせます。
数式より先に見てほしいこと
観測の 2 つのキーが同時に動いていても、内部では 1 つの遷移規則がそれを支えているかもしれません。状態空間モデルは、その隠れた共通原因を置くための道具として読むと自然です。
この notebook の立ち位置
本格的な推定アルゴリズムまでは踏み込みません。ここでは『隠れ状態を置くと予測と解釈がどう整理されるか』に絞って見ます。
観察 1: 状態空間モデルの遷移
状態空間モデルとして、線形遷移の最小形を定義します。
z_t = 0.35
a_t = -0.3
A, B = 0.88, 0.22
z_next = A * z_t + B * a_t
print('task = state-space-models')
print('z_next =', round(z_next, 6))
この線形遷移を起点に表現学習へ接続します。ここでは A、B、decode を手書きで置く最小例ですが、実際の状態空間モデルでは観測データからこれらの係数を推定・学習します。
観察 2: 観測予測を作る
次に、潜在状態から観測を復元する写像を作ります。状態推定と観測再現の役割分担をコードで掴みます。
def decode(z):
return {'position': 2.5 * z + 0.1, 'velocity': 0.8 * z - 0.05}
obs_next = decode(z_next)
print('obs_next =', {k: round(v, 4) for k, v in obs_next.items()})
print('keys =', list(obs_next.keys()))
観測予測を別関数に切ると、遷移誤差と観測誤差を分離して調整できます。
計算の対応表
観察 3: ロールアウトを試す
ここで複数ステップ予測を実行します。1ステップでは見えない誤差累積を把握するためです。
actions = [0.0, 1.0, 1.0, 0.0, -0.5]
z = 0.1
traj = []
for a in actions:
z = 0.92 * z + 0.18 * a
traj.append(round(z, 5))
print('rollout =', traj)
長期予測で崩れるなら、遷移モデルの安定性や状態表現の情報量不足を疑います。
観察 4: 計画候補を比較する
次に、複数の行動列を比較して、どの計画が望ましいかを評価します。モデルベース強化学習の中心操作です。
plans = [[0, 1, 1], [1, 1, 1], [0, 0, 1]]
def score_plan(plan):
z = 0.1
for a in plan:
z = 0.92 * z + 0.18 * a
return z
scores = [round(score_plan(p), 5) for p in plans]
print('scores =', scores)
計画評価が可能になると、実環境での試行回数を抑えた探索がしやすくなります。
観察 5: モデル誤差を監視する
最後に、予測と実測の差を定量化します。世界モデルは『予測できる範囲』を常に点検する運用が重要です。
pred = [0.10, 0.22, 0.31, 0.29]
real = [0.11, 0.25, 0.28, 0.35]
errors = [abs(p - r) for p, r in zip(pred, real)]
print('errors =', [round(e, 4) for e in errors])
print('mean_error =', round(sum(errors) / len(errors), 5))
平均誤差だけでなく時点別誤差を追うと、どの遷移条件でモデルが弱いかを特定しやすくなります。
読み終えたあとに残したい視点
- 観測はそのまま内部状態ではない。
- 状態遷移と観測生成を分けると、どこで誤差が出たか切り分けやすい。
- 長期予測の安定性は、状態の置き方に強く依存する。