再帰型ニューラルネットワーク(RNN/LSTM/GRU)

系列データでは、いまの入力だけ見ても答えが決まらないことがよくあります。RNN 系モデルは、そのために過去の情報を「隠れ状態」として持ち回ります。

このノートでは、単純 RNN の更新式から始めて、なぜ長い系列で苦しくなるのかを見て、そのあとで LSTM と GRU が何を足して改善しているのかを追います。

系列モデルの本題は、いまの入力より「前に見たもの」をどう残すかです

系列タスクでは、現在の単語や値だけでは判断できないことがよくあります。RNN はそのために隠れ状態を持ち回りますが、その仕組みは長い系列になると急に苦しくなります。

このノートでは、単純 RNN の更新式を追って記憶の流れを見たあと、勾配消失がなぜ起こるかを確認します。そのうえで LSTM と GRU が、何を忘れ、何を通すためにゲートを入れたのかを整理します。

読み筋は「記憶」と「壊れ方」の 2 段構えです

前半では、隠れ状態が系列の前半の情報をどう運ぶかを見ます。後半では、その運び方が長くなるとどこで壊れるかを見て、LSTM / GRU がその壊れ方にどう手を打っているかへ進みます。

単純 RNN は小さいが、系列モデルの本質を全部含んでいる

前時刻の状態を次へ渡す。この一点だけで、系列モデルの核はほぼ出そろっています。だからこそ、まずは単純 RNN をきちんと読んだ方が、後のゲート付きモデルの意味も見えやすくなります。

LSTM と GRU は「記憶力が高い別モデル」ではなく、情報の交通整理を追加したもの

何を保持し、何を捨て、何を露出させるか。その制御を学習可能にしたのがゲートです。この観点で読むと、式が増えても役割の違いを追いやすくなります。

最後の比較は、性能勝負ではなく記憶の持ち方を見るためのもの

ここでの合成タスクは小さいですが、各モデルが遠い過去をどこまで引きずれるかを見るには十分です。数値の優劣だけでなく、どんな系列長で差が出るかに注目して読んでください。

まずは単純 RNN を 1 時刻ずつ追う

最初に、最も単純な RNN の更新を見ます。入力と前の隠れ状態から、次の隠れ状態がどう作られるかをここで掴みます。

import math
import numpy as np
import matplotlib.pyplot as plt

try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    TORCH_AVAILABLE = True
except ModuleNotFoundError:
    torch = None
    nn = None
    optim = None
    TORCH_AVAILABLE = False

最初に、最も単純なRNNの更新を実装します。

ここでは 1 次元入力を仮定します。
xtx_t は時刻 t の入力、hth_t は時刻 t の隠れ状態(記憶)です。
WxW_x は入力に掛かる重み、WhW_h は前時刻の記憶に掛かる重み、b はバイアスです。

hth_t = tanh(Wxtanh(W_x xtx_t + WhW_h ht1h_{t-1} + b)

この式の意味は「新しい入力 xtx_t と、ひとつ前の記憶 ht1h_{t-1} を混ぜて、新しい記憶 hth_t を作る」です。

def simple_rnn_forward(xs, wxh=0.9, whh=0.8, bh=0.0):
    h = 0.0
    history = []
    for t, x_t in enumerate(xs):
        # pre = W_x x_t + W_h h_{t-1} + b
        pre = wxh * x_t + whh * h + bh
        # h_t = tanh(pre)
        h = math.tanh(pre)
        history.append({'t': t, 'x_t': x_t, 'pre': pre, 'h_t': h})
    return h, history


sequence = [0.2, -0.1, 0.5, 0.3, -0.4, 0.1]
final_h, hist = simple_rnn_forward(sequence, wxh=1.0, whh=0.75, bh=0.0)

for row in hist:
    print(f"t={row['t']:>2d}, x_t={row['x_t']:>5.2f}, pre={row['pre']:>7.4f}, h_t={row['h_t']:>7.4f}")

print('final hidden state =', round(final_h, 4))

なぜ長い系列で苦しくなるのか

次は勾配消失です。時間をまたぐたびに係数の掛け算が続くと、遠い過去へ届く勾配が急速に小さくなることがあります。

t_axis = [r['t'] for r in hist]
x_axis = [r['x_t'] for r in hist]
h_axis = [r['h_t'] for r in hist]

plt.figure(figsize=(7.2, 3.6))
plt.plot(t_axis, x_axis, marker='o', label='input x_t')
plt.plot(t_axis, h_axis, marker='s', label='hidden h_t')
plt.axhline(0, color='#999999', linewidth=1)
plt.xlabel('time step t')
plt.ylabel('value')
plt.title('Simple RNN: input vs hidden state')
plt.legend()
plt.tight_layout()
plt.show()

RNNの弱点として、長い系列になると初期の情報が消えやすい問題があります。
逆伝播では、時刻をまたいで勾配が連鎖的に掛け算されるため、係数が1より小さいと急速に小さくなります。
この現象が勾配消失(vanishing gradient)です。

下の可視化は直感用の近似で、|W_h|^T のスカラー連鎖だけを取り出して見ています。
実際には tanh の導関数やヤコビアン積も効くので、ここでの図は「雰囲気の把握」が目的です。

def recurrent_chain_gain(whh, steps):
    return np.abs(whh) ** steps


steps = np.arange(1, 61)
for whh in [0.5, 0.9, 1.1]:
    gains = recurrent_chain_gain(whh, steps)
    print(f'whh={whh}: step 1 -> {gains[0]:.6f}, step 30 -> {gains[29]:.6f}, step 60 -> {gains[59]:.6f}')

plt.figure(figsize=(7.2, 3.6))
for whh in [0.5, 0.9, 1.1]:
    plt.plot(steps, recurrent_chain_gain(whh, steps), label=f'|whh|={whh}')
plt.yscale('log')
plt.xlabel('time distance')
plt.ylabel('gradient scale (log)')
plt.title('Why long-term dependency is hard for plain RNN')
plt.legend()
plt.tight_layout()
plt.show()

LSTM は何を足しているのか

LSTM は、忘れるか、足すか、出すかをゲートで調整します。単純 RNN の「全部まとめて流す」構造より、記憶の交通整理がしやすくなっています。

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_lstm_params(input_size=1, hidden_size=2, seed=7):
    rng = np.random.default_rng(seed)
    concat_size = input_size + hidden_size

    def w(shape):
        return rng.normal(0.0, 0.35, size=shape)

    params = {
        'W_i': w((hidden_size, concat_size)),
        'b_i': np.zeros(hidden_size),
        'W_f': w((hidden_size, concat_size)),
        'b_f': np.zeros(hidden_size),
        'W_g': w((hidden_size, concat_size)),
        'b_g': np.zeros(hidden_size),
        'W_o': w((hidden_size, concat_size)),
        'b_o': np.zeros(hidden_size),
    }
    return params


def lstm_step(x_t, h_prev, c_prev, p):
    # x_t:(1,), h_prev/c_prev:(2,), W_*:(2,3), b_*:(2,)
    z = np.concatenate([x_t, h_prev])

    i_t = sigmoid(p['W_i'] @ z + p['b_i'])
    f_t = sigmoid(p['W_f'] @ z + p['b_f'])
    g_t = np.tanh(p['W_g'] @ z + p['b_g'])
    o_t = sigmoid(p['W_o'] @ z + p['b_o'])

    c_t = f_t * c_prev + i_t * g_t
    h_t = o_t * np.tanh(c_t)

    return h_t, c_t, {'i_t': i_t, 'f_t': f_t, 'g_t': g_t, 'o_t': o_t}


params = init_lstm_params(input_size=1, hidden_size=2, seed=7)
xs = [np.array([v]) for v in [0.6, -0.2, 0.1, 0.5, -0.4]]
h = np.zeros(2)
c = np.zeros(2)

for t, x_t in enumerate(xs):
    h, c, gates = lstm_step(x_t, h, c, params)
    print(f't={t}, x={x_t[0]:>5.2f}')
    print('  i_t=', np.round(gates['i_t'], 4), 'f_t=', np.round(gates['f_t'], 4), 'o_t=', np.round(gates['o_t'], 4))
    print('  c_t=', np.round(c, 4), 'h_t=', np.round(h, 4))

print('読み方の目安: f_t が 1 に近いほど過去記憶を残し、0 に近いほど忘れます。')

GRU はどこを簡略化したのか

GRU は LSTM より少し軽い設計で、セル状態を分けずに隠れ状態だけで更新します。何を省いて何を残したかに注目すると、LSTM との違いが見やすくなります。

def count_rnn_params_single_bias(input_size, hidden_size):
    # 理論の最小形(重み1組 + bias1組)
    return hidden_size * (input_size + hidden_size) + hidden_size


def count_params_pytorch_style(gates, input_size, hidden_size):
    # PyTorchは bias_ih と bias_hh の2本を持つ
    return gates * (hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size)


input_size = 16
hidden_size = 64

print('--- single-bias theoretical count ---')
print('RNN :', count_rnn_params_single_bias(input_size, hidden_size))
print('LSTM:', 4 * count_rnn_params_single_bias(input_size, hidden_size))
print('GRU :', 3 * count_rnn_params_single_bias(input_size, hidden_size))

print('--- PyTorch parameter count (bias_ih + bias_hh) ---')
print('RNN :', count_params_pytorch_style(1, input_size, hidden_size))
print('LSTM:', count_params_pytorch_style(4, input_size, hidden_size))
print('GRU :', count_params_pytorch_style(3, input_size, hidden_size))

同じ系列課題で比べてみる

最後に PyTorch で RNN / LSTM / GRU を同じ課題にかけます。ここでは「系列の最初の情報を最後まで持てるか」が勝負なので、記憶保持の差が見やすい題材です。

def init_gru_params(input_size=1, hidden_size=2, seed=9):
    rng = np.random.default_rng(seed)
    concat_size = input_size + hidden_size

    def w(shape):
        return rng.normal(0.0, 0.35, size=shape)

    return {
        'W_z': w((hidden_size, concat_size)),
        'b_z': np.zeros(hidden_size),
        'W_r': w((hidden_size, concat_size)),
        'b_r': np.zeros(hidden_size),
        'W_h': w((hidden_size, concat_size)),
        'b_h': np.zeros(hidden_size),
    }


def gru_step(x_t, h_prev, p):
    z_in = np.concatenate([x_t, h_prev])
    z_t = sigmoid(p['W_z'] @ z_in + p['b_z'])
    r_t = sigmoid(p['W_r'] @ z_in + p['b_r'])

    h_candidate_in = np.concatenate([x_t, r_t * h_prev])
    h_tilde = np.tanh(p['W_h'] @ h_candidate_in + p['b_h'])

    # PyTorchと同じ流儀: z_t は過去状態を残す割合
    h_t = z_t * h_prev + (1.0 - z_t) * h_tilde
    return h_t, {'z_t': z_t, 'r_t': r_t, 'h_tilde': h_tilde}


gru_params = init_gru_params(input_size=1, hidden_size=2, seed=9)
h = np.zeros(2)
for t, x_t in enumerate(xs):
    h, gates = gru_step(x_t, h, gru_params)
    print(f't={t}, x={x_t[0]:>5.2f}, z_t={np.round(gates["z_t"],4)}, r_t={np.round(gates["r_t"],4)}, h_t={np.round(h,4)}')

ここからはPyTorchで、RNN/LSTM/GRUの違いを同じ課題で比較します。
課題は「系列の最初の値が正か負かを、最後の時刻で判定する」です。
途中にノイズが多いため、初期情報を保持できるモデルほど有利です。

logit は確率に変換する前の値で、BCEWithLogitsLoss は2値分類の標準損失です。
予測時は sigmoid(logit) >= 0.5 を陽性判定に使います。

期待値としては RNN < (GRU, LSTM) になりやすいですが、データ分布やハイパーパラメータで逆転もあります。

if TORCH_AVAILABLE:
    np.random.seed(0)

    def build_memory_dataset(n_samples=768, seq_len=28):
        x = np.random.randn(n_samples, seq_len, 1).astype(np.float32)
        y = (x[:, 0, 0] > 0).astype(np.float32)
        return torch.tensor(x), torch.tensor(y)

    x_train, y_train = build_memory_dataset(n_samples=768, seq_len=28)
    x_val, y_val = build_memory_dataset(n_samples=256, seq_len=28)

    class SequenceClassifier(nn.Module):
        def __init__(self, cell_type='RNN', hidden_size=24):
            super().__init__()
            if cell_type == 'RNN':
                self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size, batch_first=True)
            elif cell_type == 'LSTM':
                self.rnn = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
            elif cell_type == 'GRU':
                self.rnn = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
            else:
                raise ValueError('Unknown cell_type')
            self.head = nn.Linear(hidden_size, 1)

        def forward(self, x):
            out, _ = self.rnn(x)
            last = out[:, -1, :]
            logit = self.head(last).squeeze(-1)
            return logit

    def train_and_eval(cell_type, seed=0, epochs=6, lr=1e-2):
        torch.manual_seed(seed)
        model = SequenceClassifier(cell_type=cell_type, hidden_size=24)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        batch_size = 64
        for _ in range(epochs):
            perm = torch.randperm(x_train.size(0))
            for i in range(0, x_train.size(0), batch_size):
                idx = perm[i:i+batch_size]
                xb, yb = x_train[idx], y_train[idx]

                optimizer.zero_grad()
                logits = model(xb)
                loss = criterion(logits, yb)
                loss.backward()
                optimizer.step()

        with torch.no_grad():
            logits = model(x_val)
            preds = (torch.sigmoid(logits) >= 0.5).float()
            acc = (preds == y_val).float().mean().item()
        return acc

    seeds = [0, 1, 2]
    for ct in ['RNN', 'LSTM', 'GRU']:
        accs = [train_and_eval(ct, seed=s) for s in seeds]
        print(f'{ct} accuracy: mean={np.mean(accs):.4f}, std={np.std(accs):.4f}, each={np.round(accs,4)}')

    print('注: hidden_sizeを固定した簡易比較で、パラメータ数を厳密に一致させた比較ではありません。')
else:
    print('PyTorch未導入のため比較実験セルはスキップしました。')

実務では、まず GRU か LSTM を基準にし、必要に応じて単純 RNN を比較に置くのが安全です。系列が長いほど、記憶の保持を明示的に助ける設計の価値が大きくなります。