マルチモーダルな世界モデル

実環境では、画像特徴・言語指示・行動履歴を同時に使って次状態を予測する必要があります。マルチモーダル融合は、この情報統合を行う中心技術です。

現実の次状態は、一種類の観測だけでは決まりにくい

ロボットやエージェントが次に何を見るかは、画像だけでも、言語だけでも、行動だけでも決まりません。世界モデルが頑健になるのは、複数の情報源を無理なくまとめられたときです。

この notebook では画像特徴、言語特徴、行動を並べて、単独入力と融合入力の差を見ます。深い融合ネットワークの完全版ではありませんが、そのぶん『どの情報が足りないと何が読めなくなるか』が見えやすくなっています。

融合の価値は、足し算より取りこぼしの減少にある

マルチモーダル化の狙いは、派手に情報量を増やすことよりも、一つのモダリティが弱い場面で判断を落とさないことにあります。画像だけでは見えない意図を言語が補い、言語だけでは足りない物理状態を行動履歴が補う、という関係に注目します。

この notebook の見どころ

image-onlytext-onlymultimodal の誤差差分を見ながら、どの入力の組み合わせが次状態の 2 成分を支えているかを読みます。

ここで使うのは線形予測器ですが、だからこそ各モダリティの寄与が見やすくなります。まずは融合の必要性を理解し、そのあと複雑な深層融合へ進むのが自然です。

どの情報が欠けると何が困るか

単独モダリティの誤差を見ると、その入力だけでは拾えない成分が浮かびます。融合モデルは万能だから強いのではなく、取りこぼしを減らすから強いのだと読むと筋が通ります。

読み方の軸

この notebook は深層マルチモーダル世界モデルの完全版ではありません。融合の必要性と、情報源ごとの役割分担を最小例で確認するためのものです。

それぞれの情報源を並べる

まずは画像特徴、言語特徴、行動を別々の入力として用意し、次状態との関係を観察します。

import numpy as np
np.random.seed(9)

n = 240
img_feat = np.random.randn(n, 2)
text_feat = np.random.randn(n, 1)
action = np.random.randn(n, 1)

# 真の次状態(2次元)
y = (
    0.7 * img_feat
    + np.hstack([0.5 * text_feat, -0.3 * text_feat])
    + np.hstack([0.4 * action, 0.2 * action])
    + 0.05 * np.random.randn(n, 2)
)

単独入力と融合入力を比べる

次に予測器を当てて、どの情報を合わせると誤差が下がるかを確かめます。

def fit_linear(X, Y):
    X1 = np.hstack([X, np.ones((len(X), 1))])
    W = np.linalg.pinv(X1) @ Y
    pred = X1 @ W
    mse = np.mean((pred - Y) ** 2)
    return mse

mse_img = fit_linear(img_feat, y)
mse_text = fit_linear(text_feat, y)
mse_multi = fit_linear(np.hstack([img_feat, text_feat, action]), y)

print('MSE image-only =', round(mse_img, 6))
print('MSE text-only  =', round(mse_text, 6))
print('MSE multimodal =', round(mse_multi, 6))

マルチモーダル化の本質は、モダリティを増やすこと自体ではなく、必要な情報を取りこぼさないことです。世界モデルで融合を考えるときは、この視点が出発点になります。