再帰型ニューラルネットワーク(RNN/LSTM/GRU)
系列データでは、いまの入力だけ見ても答えが決まらないことがよくあります。RNN 系モデルは、そのために過去の情報を「隠れ状態」として持ち回ります。
このノートでは、単純 RNN の更新式から始めて、なぜ長い系列で苦しくなるのかを見て、そのあとで LSTM と GRU が何を足して改善しているのかを追います。
系列モデルの本題は、いまの入力より「前に見たもの」をどう残すかです
系列タスクでは、現在の単語や値だけでは判断できないことがよくあります。RNN はそのために隠れ状態を持ち回りますが、その仕組みは長い系列になると急に苦しくなります。
このノートでは、単純 RNN の更新式を追って記憶の流れを見たあと、勾配消失がなぜ起こるかを確認します。そのうえで LSTM と GRU が、何を忘れ、何を通すためにゲートを入れたのかを整理します。
読み筋は「記憶」と「壊れ方」の 2 段構えです
前半では、隠れ状態が系列の前半の情報をどう運ぶかを見ます。後半では、その運び方が長くなるとどこで壊れるかを見て、LSTM / GRU がその壊れ方にどう手を打っているかへ進みます。
単純 RNN は小さいが、系列モデルの本質を全部含んでいる
前時刻の状態を次へ渡す。この一点だけで、系列モデルの核はほぼ出そろっています。だからこそ、まずは単純 RNN をきちんと読んだ方が、後のゲート付きモデルの意味も見えやすくなります。
LSTM と GRU は「記憶力が高い別モデル」ではなく、情報の交通整理を追加したもの
何を保持し、何を捨て、何を露出させるか。その制御を学習可能にしたのがゲートです。この観点で読むと、式が増えても役割の違いを追いやすくなります。
最後の比較は、性能勝負ではなく記憶の持ち方を見るためのもの
ここでの合成タスクは小さいですが、各モデルが遠い過去をどこまで引きずれるかを見るには十分です。数値の優劣だけでなく、どんな系列長で差が出るかに注目して読んでください。
まずは単純 RNN を 1 時刻ずつ追う
最初に、最も単純な RNN の更新を見ます。入力と前の隠れ状態から、次の隠れ状態がどう作られるかをここで掴みます。
import math
import numpy as np
import matplotlib.pyplot as plt
try:
import torch
import torch.nn as nn
import torch.optim as optim
TORCH_AVAILABLE = True
except ModuleNotFoundError:
torch = None
nn = None
optim = None
TORCH_AVAILABLE = False
最初に、最も単純なRNNの更新を実装します。
ここでは 1 次元入力を仮定します。
は時刻 t の入力、 は時刻 t の隠れ状態(記憶)です。
は入力に掛かる重み、 は前時刻の記憶に掛かる重み、b はバイアスです。
= + + b)
この式の意味は「新しい入力 と、ひとつ前の記憶 を混ぜて、新しい記憶 を作る」です。
def simple_rnn_forward(xs, wxh=0.9, whh=0.8, bh=0.0):
h = 0.0
history = []
for t, x_t in enumerate(xs):
# pre = W_x x_t + W_h h_{t-1} + b
pre = wxh * x_t + whh * h + bh
# h_t = tanh(pre)
h = math.tanh(pre)
history.append({'t': t, 'x_t': x_t, 'pre': pre, 'h_t': h})
return h, history
sequence = [0.2, -0.1, 0.5, 0.3, -0.4, 0.1]
final_h, hist = simple_rnn_forward(sequence, wxh=1.0, whh=0.75, bh=0.0)
for row in hist:
print(f"t={row['t']:>2d}, x_t={row['x_t']:>5.2f}, pre={row['pre']:>7.4f}, h_t={row['h_t']:>7.4f}")
print('final hidden state =', round(final_h, 4))
なぜ長い系列で苦しくなるのか
次は勾配消失です。時間をまたぐたびに係数の掛け算が続くと、遠い過去へ届く勾配が急速に小さくなることがあります。
t_axis = [r['t'] for r in hist]
x_axis = [r['x_t'] for r in hist]
h_axis = [r['h_t'] for r in hist]
plt.figure(figsize=(7.2, 3.6))
plt.plot(t_axis, x_axis, marker='o', label='input x_t')
plt.plot(t_axis, h_axis, marker='s', label='hidden h_t')
plt.axhline(0, color='#999999', linewidth=1)
plt.xlabel('time step t')
plt.ylabel('value')
plt.title('Simple RNN: input vs hidden state')
plt.legend()
plt.tight_layout()
plt.show()
RNNの弱点として、長い系列になると初期の情報が消えやすい問題があります。
逆伝播では、時刻をまたいで勾配が連鎖的に掛け算されるため、係数が1より小さいと急速に小さくなります。
この現象が勾配消失(vanishing gradient)です。
下の可視化は直感用の近似で、|W_h|^T のスカラー連鎖だけを取り出して見ています。
実際には tanh の導関数やヤコビアン積も効くので、ここでの図は「雰囲気の把握」が目的です。
def recurrent_chain_gain(whh, steps):
return np.abs(whh) ** steps
steps = np.arange(1, 61)
for whh in [0.5, 0.9, 1.1]:
gains = recurrent_chain_gain(whh, steps)
print(f'whh={whh}: step 1 -> {gains[0]:.6f}, step 30 -> {gains[29]:.6f}, step 60 -> {gains[59]:.6f}')
plt.figure(figsize=(7.2, 3.6))
for whh in [0.5, 0.9, 1.1]:
plt.plot(steps, recurrent_chain_gain(whh, steps), label=f'|whh|={whh}')
plt.yscale('log')
plt.xlabel('time distance')
plt.ylabel('gradient scale (log)')
plt.title('Why long-term dependency is hard for plain RNN')
plt.legend()
plt.tight_layout()
plt.show()
LSTM は何を足しているのか
LSTM は、忘れるか、足すか、出すかをゲートで調整します。単純 RNN の「全部まとめて流す」構造より、記憶の交通整理がしやすくなっています。
def sigmoid(z):
return 1.0 / (1.0 + np.exp(-z))
def init_lstm_params(input_size=1, hidden_size=2, seed=7):
rng = np.random.default_rng(seed)
concat_size = input_size + hidden_size
def w(shape):
return rng.normal(0.0, 0.35, size=shape)
params = {
'W_i': w((hidden_size, concat_size)),
'b_i': np.zeros(hidden_size),
'W_f': w((hidden_size, concat_size)),
'b_f': np.zeros(hidden_size),
'W_g': w((hidden_size, concat_size)),
'b_g': np.zeros(hidden_size),
'W_o': w((hidden_size, concat_size)),
'b_o': np.zeros(hidden_size),
}
return params
def lstm_step(x_t, h_prev, c_prev, p):
# x_t:(1,), h_prev/c_prev:(2,), W_*:(2,3), b_*:(2,)
z = np.concatenate([x_t, h_prev])
i_t = sigmoid(p['W_i'] @ z + p['b_i'])
f_t = sigmoid(p['W_f'] @ z + p['b_f'])
g_t = np.tanh(p['W_g'] @ z + p['b_g'])
o_t = sigmoid(p['W_o'] @ z + p['b_o'])
c_t = f_t * c_prev + i_t * g_t
h_t = o_t * np.tanh(c_t)
return h_t, c_t, {'i_t': i_t, 'f_t': f_t, 'g_t': g_t, 'o_t': o_t}
params = init_lstm_params(input_size=1, hidden_size=2, seed=7)
xs = [np.array([v]) for v in [0.6, -0.2, 0.1, 0.5, -0.4]]
h = np.zeros(2)
c = np.zeros(2)
for t, x_t in enumerate(xs):
h, c, gates = lstm_step(x_t, h, c, params)
print(f't={t}, x={x_t[0]:>5.2f}')
print(' i_t=', np.round(gates['i_t'], 4), 'f_t=', np.round(gates['f_t'], 4), 'o_t=', np.round(gates['o_t'], 4))
print(' c_t=', np.round(c, 4), 'h_t=', np.round(h, 4))
print('読み方の目安: f_t が 1 に近いほど過去記憶を残し、0 に近いほど忘れます。')
GRU はどこを簡略化したのか
GRU は LSTM より少し軽い設計で、セル状態を分けずに隠れ状態だけで更新します。何を省いて何を残したかに注目すると、LSTM との違いが見やすくなります。
def count_rnn_params_single_bias(input_size, hidden_size):
# 理論の最小形(重み1組 + bias1組)
return hidden_size * (input_size + hidden_size) + hidden_size
def count_params_pytorch_style(gates, input_size, hidden_size):
# PyTorchは bias_ih と bias_hh の2本を持つ
return gates * (hidden_size * input_size + hidden_size * hidden_size + 2 * hidden_size)
input_size = 16
hidden_size = 64
print('--- single-bias theoretical count ---')
print('RNN :', count_rnn_params_single_bias(input_size, hidden_size))
print('LSTM:', 4 * count_rnn_params_single_bias(input_size, hidden_size))
print('GRU :', 3 * count_rnn_params_single_bias(input_size, hidden_size))
print('--- PyTorch parameter count (bias_ih + bias_hh) ---')
print('RNN :', count_params_pytorch_style(1, input_size, hidden_size))
print('LSTM:', count_params_pytorch_style(4, input_size, hidden_size))
print('GRU :', count_params_pytorch_style(3, input_size, hidden_size))
同じ系列課題で比べてみる
最後に PyTorch で RNN / LSTM / GRU を同じ課題にかけます。ここでは「系列の最初の情報を最後まで持てるか」が勝負なので、記憶保持の差が見やすい題材です。
def init_gru_params(input_size=1, hidden_size=2, seed=9):
rng = np.random.default_rng(seed)
concat_size = input_size + hidden_size
def w(shape):
return rng.normal(0.0, 0.35, size=shape)
return {
'W_z': w((hidden_size, concat_size)),
'b_z': np.zeros(hidden_size),
'W_r': w((hidden_size, concat_size)),
'b_r': np.zeros(hidden_size),
'W_h': w((hidden_size, concat_size)),
'b_h': np.zeros(hidden_size),
}
def gru_step(x_t, h_prev, p):
z_in = np.concatenate([x_t, h_prev])
z_t = sigmoid(p['W_z'] @ z_in + p['b_z'])
r_t = sigmoid(p['W_r'] @ z_in + p['b_r'])
h_candidate_in = np.concatenate([x_t, r_t * h_prev])
h_tilde = np.tanh(p['W_h'] @ h_candidate_in + p['b_h'])
# PyTorchと同じ流儀: z_t は過去状態を残す割合
h_t = z_t * h_prev + (1.0 - z_t) * h_tilde
return h_t, {'z_t': z_t, 'r_t': r_t, 'h_tilde': h_tilde}
gru_params = init_gru_params(input_size=1, hidden_size=2, seed=9)
h = np.zeros(2)
for t, x_t in enumerate(xs):
h, gates = gru_step(x_t, h, gru_params)
print(f't={t}, x={x_t[0]:>5.2f}, z_t={np.round(gates["z_t"],4)}, r_t={np.round(gates["r_t"],4)}, h_t={np.round(h,4)}')
ここからはPyTorchで、RNN/LSTM/GRUの違いを同じ課題で比較します。
課題は「系列の最初の値が正か負かを、最後の時刻で判定する」です。
途中にノイズが多いため、初期情報を保持できるモデルほど有利です。
logit は確率に変換する前の値で、BCEWithLogitsLoss は2値分類の標準損失です。
予測時は sigmoid(logit) >= 0.5 を陽性判定に使います。
期待値としては RNN < (GRU, LSTM) になりやすいですが、データ分布やハイパーパラメータで逆転もあります。
if TORCH_AVAILABLE:
np.random.seed(0)
def build_memory_dataset(n_samples=768, seq_len=28):
x = np.random.randn(n_samples, seq_len, 1).astype(np.float32)
y = (x[:, 0, 0] > 0).astype(np.float32)
return torch.tensor(x), torch.tensor(y)
x_train, y_train = build_memory_dataset(n_samples=768, seq_len=28)
x_val, y_val = build_memory_dataset(n_samples=256, seq_len=28)
class SequenceClassifier(nn.Module):
def __init__(self, cell_type='RNN', hidden_size=24):
super().__init__()
if cell_type == 'RNN':
self.rnn = nn.RNN(input_size=1, hidden_size=hidden_size, batch_first=True)
elif cell_type == 'LSTM':
self.rnn = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
elif cell_type == 'GRU':
self.rnn = nn.GRU(input_size=1, hidden_size=hidden_size, batch_first=True)
else:
raise ValueError('Unknown cell_type')
self.head = nn.Linear(hidden_size, 1)
def forward(self, x):
out, _ = self.rnn(x)
last = out[:, -1, :]
logit = self.head(last).squeeze(-1)
return logit
def train_and_eval(cell_type, seed=0, epochs=6, lr=1e-2):
torch.manual_seed(seed)
model = SequenceClassifier(cell_type=cell_type, hidden_size=24)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
batch_size = 64
for _ in range(epochs):
perm = torch.randperm(x_train.size(0))
for i in range(0, x_train.size(0), batch_size):
idx = perm[i:i+batch_size]
xb, yb = x_train[idx], y_train[idx]
optimizer.zero_grad()
logits = model(xb)
loss = criterion(logits, yb)
loss.backward()
optimizer.step()
with torch.no_grad():
logits = model(x_val)
preds = (torch.sigmoid(logits) >= 0.5).float()
acc = (preds == y_val).float().mean().item()
return acc
seeds = [0, 1, 2]
for ct in ['RNN', 'LSTM', 'GRU']:
accs = [train_and_eval(ct, seed=s) for s in seeds]
print(f'{ct} accuracy: mean={np.mean(accs):.4f}, std={np.std(accs):.4f}, each={np.round(accs,4)}')
print('注: hidden_sizeを固定した簡易比較で、パラメータ数を厳密に一致させた比較ではありません。')
else:
print('PyTorch未導入のため比較実験セルはスキップしました。')
実務では、まず GRU か LSTM を基準にし、必要に応じて単純 RNN を比較に置くのが安全です。系列が長いほど、記憶の保持を明示的に助ける設計の価値が大きくなります。