事前学習と最小GPT実装

事前学習を理解するときに大事なのは、巨大なモデル名を並べることではなく、次トークン予測がどんな計算として実装されているかを追うことです。

このノートでは、最小GPT実装を題材にして、トークン化、Self-Attention、損失、逆伝播、Adam、生成までを一本の流れで見ます。抽象論としての pre-training ではなく、「実装を読むと何が分かるか」を中心に進みます。

この内容は mani1261790/mictogpt_walkthrough$ の walkthrough を土台に、Noema の LLM 節向けに組み込んだものです。


このノートの到達点

ここを最後まで読むと、次ができる状態を目指します。

  1. 次トークン予測が事前学習の中心にある理由を説明できる
  2. Transformer の最小構成を、式とコードの両方で追える
  3. loss.backward() の裏で何が起きているかを言葉で説明できる
  4. 小さな GPT 実装を読み、どこを改造すれば挙動が変わるか判断できる

目次

  1. 言語モデルの数学的定義
  2. データとトークナイザ
  3. 自作Autograd(連鎖律をコード化)
  4. パラメータ初期化
  5. 線形層・Softmax・RMSNorm
  6. Self-Attention理論と実装
  7. GPT本体(1トークンずつ前進)
  8. 損失とTeacher Forcing
  9. 逆伝播とAdam
  10. 学習ループ
  11. 推論(サンプリング)
  12. 注意重みの観察
  13. 計算量とボトルネック
  14. よくある実装上の落とし穴
  15. 追加課題
  16. 最終まとめ

1. 言語モデルの数学的定義

言語モデル(LM)は、トークン列 x1,x2,...,xTx_1, x_2, ..., x_T の同時確率を次のように分解します。

P(x1,...,xT)=t=1TP(xtx<t)P(x_1, ..., x_T) = \prod_{t=1}^{T} P(x_t \mid x_{<t})

数式が苦手な人向けに一言で言うと、
「文章全体のもっともらしさ = 1文字(1トークン)ずつ当てるゲームの掛け算」 です。

たとえば「ね」「こ」「です」を作るなら、

ここで x_{<t} は位置 t より前の文脈です。
つまり、LMの仕事は "次トークン分布を出す関数" を作ることです。

学習では負の対数尤度(クロスエントロピー)を最小化します。

L=1Tt=1TlogP(xttruex<t)\mathcal{L} = -\frac{1}{T}\sum_{t=1}^{T} \log P(x_t^{\text{true}} \mid x_{<t})

これも日本語で言うと、
「正解にどれだけ自信を持てたかの平均点」 です。

ログが出てくる理由は、掛け算を足し算に変えて扱いやすくするためです。
数式がわからなくても、「正解を高確率にするほど良い」 だけ覚えれば十分です。

このノートの全実装は、この式を最小限で実現しています。

2. 全体アーキテクチャ(図)

入力トークン列
   |
   v
[Token Embedding] + [Position Embedding]
   |
   v
(Layer x N)
  ├─ RMSNorm -> Multi-Head Self-Attention -> Residual Add
  └─ RMSNorm -> MLP (Linear-ReLU-Linear) -> Residual Add
   |
   v
LM Head (Linear to vocab logits)
   |
   v
Softmax -> 次トークン確率

本稿ではこの最小形を、外部DLフレームワークなしで実装しています。

3. 準備: importと乱数シード

# ファイル存在チェックに使う標準ライブラリ
import os
# 数学関数(log/expなど)に使う
import math
# 乱数(初期化・シャッフル・サンプリング)に使う
import random
# URLからデータを取得するために使う
import urllib.request

# 乱数シードを固定して再現性を持たせる
random.seed(42)

4. データ読み込み

ここでは名前データ(1行1サンプル)を使います。

注意: これは会話データではなく文字列モデリングです。ただし、前の文脈から次トークン分布を出すという事前学習の骨格自体は、この最小例でもそのまま観察できます。

# 学習データファイルがなければダウンロードする
if not os.path.exists('input.txt'):
    # 名前データセットのURL
    names_url = 'https://raw.githubusercontent.com/karpathy/makemore/refs/heads/master/names.txt'
    # ダウンロードしてローカルに保存
    urllib.request.urlretrieve(names_url, 'input.txt')

# 1行1名前として読み込み、空行を除去してdocsを作る
docs = [l.strip() for l in open('input.txt').read().strip().split('\n') if l.strip()]
# 学習順が偏らないようにシャッフル
random.shuffle(docs)

# データ件数を確認
print(f'num docs: {len(docs)}')
# 先頭サンプルを見て内容を把握
print('sample docs:', docs[:10])

5. トークナイザ(文字単位)

なぜ文字単位にするか

特別トークン

BOS(Beginning Of Sequence)を追加して、
開始・終了を同じトークンで表現します。

訓練例: [BOS] + 文字列 + [BOS]
これで"いつ終えるか"もモデルが学べます。

# 全データに含まれる文字集合を作ってソート(文字語彙)
uchars = sorted(set(''.join(docs)))
# 文字->ID の辞書
stoi = {ch: i for i, ch in enumerate(uchars)}
# ID->文字 の辞書
itos = {i: ch for ch, i in stoi.items()}

# 特殊トークンBOSは通常文字IDの次に置く
BOS = len(uchars)
# 語彙サイズは文字語彙 + BOS
vocab_size = len(uchars) + 1

# 文字列をトークンID列に変換
def encode(text):
    # 各文字をstoiでID化
    return [stoi[ch] for ch in text]

# トークンID列を文字列に戻す
def decode(token_ids):
    # 各IDをitosで文字に戻して連結
    return ''.join(itos[i] for i in token_ids)

# 語彙サイズの確認
print('vocab size (with BOS):', vocab_size)
# BOSのID確認
print('BOS id:', BOS)
# エンコード例
print('encode("anna"):', encode('anna'))
# デコード例
print('decode:', decode(encode('anna')))

6. Autogradを自作する理由

PyTorchなら loss.backward() で終わります。
ここではあえて自作し、"何が裏で起きているか"を理解します。

連鎖律(Chain Rule)

合成関数 z = f(y), y = g(x) に対して

dzdx=dzdydydx\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}

言い換えると、
「最終結果への影響 = 途中への影響 × 入力が途中をどれだけ動かすか」 です。

さらに直感的には、

深層学習はこの掛け算を巨大グラフ上で繰り返すだけです。

# 計算グラフの1ノード(値と勾配)を表すクラス
class Value:
    # 属性を固定してメモリ使用量を抑える
    __slots__ = ('data', 'grad', '_children', '_local_grads')

    # data: 値, children/local_grads: 逆伝播情報
    def __init__(self, data, children=(), local_grads=()):
        # 順伝播で得たスカラー値
        self.data = data
        # このノードに流れる勾配(初期値0)
        self.grad = 0
        # このノードを作る元になった子ノード
        self._children = children
        # 各子に対する局所微分
        self._local_grads = local_grads

    # 加算ノードを作る
    def __add__(self, other):
        # 数値ならValue化して型を揃える
        other = other if isinstance(other, Value) else Value(other)
        # z=x+y の局所微分は dz/dx=1, dz/dy=1
        return Value(self.data + other.data, (self, other), (1, 1))

    # 乗算ノードを作る
    def __mul__(self, other):
        # 数値ならValue化して型を揃える
        other = other if isinstance(other, Value) else Value(other)
        # z=x*y の局所微分は dz/dx=y, dz/dy=x
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    # べき乗ノード
    def __pow__(self, other):
        # d(x^n)/dx = n*x^(n-1)
        return Value(self.data ** other, (self,), (other * self.data ** (other - 1),))

    # 対数ノード
    def log(self):
        # d(log x)/dx = 1/x
        return Value(math.log(self.data), (self,), (1 / self.data,))

    # 指数ノード
    def exp(self):
        # exp(x)を一度計算して値と局所微分に再利用
        e = math.exp(self.data)
        return Value(e, (self,), (e,))

    # ReLUノード
    def relu(self):
        # x>0なら1、それ以外は0が局所微分
        return Value(max(0, self.data), (self,), (float(self.data > 0),))

    # 単項マイナス
    def __neg__(self):
        # -x = x * (-1)
        return self * -1

    # 右辺加算(sum対応)
    def __radd__(self, other):
        # other + self を self + other に委譲
        return self + other

    # 減算
    def __sub__(self, other):
        # x - y = x + (-y)
        return self + (-other)

    # 右辺減算
    def __rsub__(self, other):
        # other - self
        return other + (-self)

    # 右辺乗算
    def __rmul__(self, other):
        # other * self を self * other に委譲
        return self * other

    # 除算
    def __truediv__(self, other):
        # x / y = x * y^(-1)
        return self * (other ** -1)

    # 右辺除算
    def __rtruediv__(self, other):
        # other / self
        return other * (self ** -1)

    # lossから全ノードへ勾配を流す
    def backward(self):
        # 逆伝播順を保存するリスト
        topo = []
        # DFS訪問済み管理
        visited = set()

        # 計算グラフをトポロジカル順に並べる
        def build_topo(v):
            # 未訪問ノードのみ処理
            if v not in visited:
                # 訪問済みにする
                visited.add(v)
                # 子ノードを先に辿る
                for child in v._children:
                    build_topo(child)
                # 子の後に自分を積む
                topo.append(v)

        # 出力ノード(通常loss)から探索開始
        build_topo(self)
        # dself/dself=1 で初期化
        self.grad = 1

        # 出力側から入力側へ順に勾配を伝播
        for v in reversed(topo):
            # 各子へ局所微分×上流勾配を加算
            for child, local_grad in zip(v._children, v._local_grads):
                child.grad += local_grad * v.grad

6.1 計算グラフのイメージ

例: loss = (relu(a*b + c))^2

a ----*
      \ 
       (*)----+----ReLU----(^2)----loss
      /      /
b ----*      c

逆伝播では loss 側から順に局所微分を掛け合わせ、a,b,c に勾配を流します。

# 入力ノードa
a = Value(2.0)
# 入力ノードb
b = Value(-3.0)
# 入力ノードc
c = Value(10.0)
# d = a*b + c
# (中間ノード)
d = a * b + c
# e = ReLU(d)
e = d.relu()
# 最終損失 = e^2
loss_demo = e ** 2
# 逆伝播でa,b,cへの勾配を計算
loss_demo.backward()

# 損失値を表示
print('loss:', loss_demo.data)
# aに対する勾配
print('grad a:', a.grad)
# bに対する勾配
print('grad b:', b.grad)
# cに対する勾配
print('grad c:', c.grad)

7. モデルのハイパーパラメータと重み初期化

ここでは小さな設定にしています。

本物のLLMはこれらが数桁大きいですが、構造は同じです。

# 埋め込み次元
n_embd = 16
# ヘッド数
n_head = 4
# Transformer層数
n_layer = 1
# 最大系列長
block_size = 16
# 1ヘッドあたりの次元
head_dim = n_embd // n_head

# Value行列をガウス初期化で作る

def matrix(nout, nin, std=0.08):
    # 形状[nout x nin]の2次元配列を作る
    return [[Value(random.gauss(0, std)) for _ in range(nin)] for _ in range(nout)]

# 主要パラメータを辞書で管理
state_dict = {
    # トークン埋め込み
    'wte': matrix(vocab_size, n_embd),
    # 位置埋め込み
    'wpe': matrix(block_size, n_embd),
    # 出力ヘッド
    'lm_head': matrix(vocab_size, n_embd),
}

# 各層のAttention/MLP重みを追加
for i in range(n_layer):
    # Query重み
    state_dict[f'layer{i}.attn_wq'] = matrix(n_embd, n_embd)
    # Key重み
    state_dict[f'layer{i}.attn_wk'] = matrix(n_embd, n_embd)
    # Value重み
    state_dict[f'layer{i}.attn_wv'] = matrix(n_embd, n_embd)
    # Attention出力投影重み
    state_dict[f'layer{i}.attn_wo'] = matrix(n_embd, n_embd)
    # MLP前段(拡張)重み
    state_dict[f'layer{i}.mlp_fc1'] = matrix(4 * n_embd, n_embd)
    # MLP後段(圧縮)重み
    state_dict[f'layer{i}.mlp_fc2'] = matrix(n_embd, 4 * n_embd)

# 全パラメータを1本のリストに平坦化
params = [p for mat in state_dict.values() for row in mat for p in row]
# パラメータ数を表示
print('num params:', len(params))

8. 基本演算: Linear / Softmax / RMSNorm

Linear

y=Wxy = Wx

これは「重み付き合計」です。

数式が苦手なら、「入力に係数表をかけて新しい表現へ写す」 と捉えてください。

Softmax

pi=ezijezjp_i = \frac{e^{z_i}}{\sum_j e^{z_j}}

これは「点数の配列を、合計1の確率に変換する操作」です。

つまり、モデルの“生の点数”を**「次に何を出すかの割合」**へ変換しています。

数値安定化のために z_i - max(z) を使います。

RMSNorm

RMS(x)=1dixi2+ϵ,yi=xiRMS(x)\text{RMS}(x)=\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon},\quad y_i=\frac{x_i}{\text{RMS}(x)}

これは「ベクトルの大きさを整える」処理です。

式を読まなくても、「毎回だいたい同じスケールに正規化する」 と理解すればOKです。

LayerNormより簡潔で、実務でもよく使われます。

# 全結合層(x:入力ベクトル, w:[出力次元 x 入力次元])
def linear(x, w):
    # 各出力ユニットを内積で計算
    return [sum(wi * xi for wi, xi in zip(wo, x)) for wo in w]

# Softmax(ロジットを確率分布に変換)
def softmax(logits):
    # 数値安定化のため最大値を引く
    max_val = max(v.data for v in logits)
    # 指数化
    exps = [(v - max_val).exp() for v in logits]
    # 分母(総和)
    total = sum(exps)
    # 合計1になるよう正規化
    return [e / total for e in exps]

# RMSNorm(ベクトルスケールを正規化)
def rmsnorm(x):
    # 成分二乗の平均
    ms = sum(xi * xi for xi in x) / len(x)
    # RMSの逆数スケール
    scale = (ms + 1e-5) ** -0.5
    # 全成分を同じスケールで調整
    return [xi * scale for xi in x]

9. Self-Attentionの理論

入力ベクトル列 XX から、

Q=XWQ,quadK=XWK,quadV=XWVQ=XW_Q, \\quad K=XW_K, \\quad V=XW_V

ここは次の3役を作っているだけです。

を作り、各位置で

Attn(Q,K,V)=softmax(QKTdh+M)V\text{Attn}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_h}} + M\right)V

日本語にすると、次の順です。

  1. QK を比べて「どこを見るべきか」の点数を作る
  2. Softmaxで「見る割合」に変える
  3. その割合で V を混ぜて、文脈を取り込んだ新しい表現を作る

/sqrt(d_h) は、次元が増えたとき点数が大きくなりすぎるのを防ぐスケール調整です。
式が難しければ、「必要な過去情報を、割合で集める仕組み」 と覚えれば十分です。

を計算します。M は未来を見ないためのマスクです。

この実装では、"時刻tまでしか keys/values に入っていない"ので、明示マスクなしで因果性を満たします。

9.1 図: 1ヘッド注意の流れ

現在位置 t の query q_t
        |
        v
過去0..tの keys と内積 -> スコア列 s_0..s_t
        |
        v
softmax(s) -> 重み a_0..a_t
        |
        v
values の加重和 -> 出力ベクトル o_t

重み aia_i は"どの過去位置をどれだけ参照するか"を意味します。

10. GPT本体(1トークンずつ計算)

この関数は1ステップ分を処理します。

入力:

出力:

return_attn=True で注意重み観察も可能にしています。

# 1トークン分の前向き計算を行う関数
def gpt(token_id, pos_id, keys, values, return_attn=False):
    # トークン埋め込みを取得
    tok_emb = state_dict['wte'][token_id]
    # 位置埋め込みを取得
    pos_emb = state_dict['wpe'][pos_id]
    # トークン情報と位置情報を合成
    x = [t + p for t, p in zip(tok_emb, pos_emb)]
    # 入力を正規化
    x = rmsnorm(x)

    # 注意重みを返したい場合の記録バッファ
    attn_records = [] if return_attn else None

    # Transformer層を順に適用
    for li in range(n_layer):
        # ===== Attention block =====
        # 残差接続用に入力を保存
        x_residual = x
        # Attention前の正規化
        x = rmsnorm(x)

        # 線形変換でQ/K/Vを作る
        q = linear(x, state_dict[f'layer{li}.attn_wq'])
        k = linear(x, state_dict[f'layer{li}.attn_wk'])
        v = linear(x, state_dict[f'layer{li}.attn_wv'])

        # 現在時刻のK/Vをキャッシュに追加
        keys[li].append(k)
        values[li].append(v)

        # ヘッド出力を連結するバッファ
        x_attn = []
        # ヘッドごとにAttentionを計算
        for h in range(n_head):
            # ヘッド開始オフセット
            hs = h * head_dim
            # このヘッドのQuery
            q_h = q[hs:hs + head_dim]
            # 過去を含むこのヘッドのKeys
            k_h = [ki[hs:hs + head_dim] for ki in keys[li]]
            # 過去を含むこのヘッドのValues
            v_h = [vi[hs:hs + head_dim] for vi in values[li]]

            # scaled dot-productでAttentionロジットを作る
            attn_logits = [
                sum(q_h[j] * k_h[t][j] for j in range(head_dim)) / (head_dim ** 0.5)
                for t in range(len(k_h))
            ]
            # ロジットを確率重みに変換
            attn_weights = softmax(attn_logits)

            # 可視化したい場合は生の重みを記録
            if return_attn:
                attn_records.append((li, h, [w.data for w in attn_weights]))

            # 重み付き和でヘッド出力を作る
            head_out = [
                sum(attn_weights[t] * v_h[t][j] for t in range(len(v_h)))
                for j in range(head_dim)
            ]
            # ヘッド出力を連結
            x_attn.extend(head_out)

        # ヘッド連結後の出力投影
        x = linear(x_attn, state_dict[f'layer{li}.attn_wo'])
        # Attention残差接続
        x = [a + b for a, b in zip(x, x_residual)]

        # ===== MLP block =====
        # 残差接続用に入力を保存
        x_residual = x
        # MLP前の正規化
        x = rmsnorm(x)
        # 次元拡張の線形層
        x = linear(x, state_dict[f'layer{li}.mlp_fc1'])
        # 非線形化(ReLU)
        x = [xi.relu() for xi in x]
        # 次元圧縮の線形層
        x = linear(x, state_dict[f'layer{li}.mlp_fc2'])
        # MLP残差接続
        x = [a + b for a, b in zip(x, x_residual)]

    # 最終的に語彙次元へ射影してロジットを得る
    logits = linear(x, state_dict['lm_head'])

    # 注意重みも欲しい場合は一緒に返す
    if return_attn:
        return logits, attn_records
    # 通常はロジットだけ返す
    return logits

11. 損失関数とTeacher Forcing

系列 x_0...x_n に対して、
各位置で (入力=x_t, 正解=x_{t+1}) を与えます。

これは Teacher Forcing と呼ばれます。
学習中は常に正解履歴を与えるため、安定して学習できます。

位置ごとの損失:

t=logpt(xt+1true)\ell_t = -\log p_t(x_{t+1}^{true})

これは「位置tで、正解次トークンに何点つけたかの減点」です。
正解確率が高いほど減点は小さく、低いほど大きくなります。

系列損失:

L=1ntt\mathcal{L}=\frac{1}{n}\sum_t \ell_t

つまり1文あたりの平均減点を最小化しています。
数式がつらければ、「全位置でのミスの平均」 とだけ捉えて問題ありません。

# 1つの文書に対する平均損失を計算する
def loss_for_document(doc):
    # 文頭/文末をBOSで挟んだトークン列を作る
    tokens = [BOS] + encode(doc) + [BOS]
    # 最大長block_sizeまでを学習対象にする
    n = min(block_size, len(tokens) - 1)

    # 層ごとのKVキャッシュを初期化
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    # 時刻ごとの損失をためる
    losses = []

    # 各時刻で次トークン予測を行う
    for pos_id in range(n):
        # 現在トークン
        token_id = tokens[pos_id]
        # 正解は1つ先のトークン
        target_id = tokens[pos_id + 1]

        # ロジット計算
        logits = gpt(token_id, pos_id, keys, values)
        # 確率分布に変換
        probs = softmax(logits)
        # クロスエントロピー(-log p(correct))
        losses.append(-probs[target_id].log())

    # 系列平均損失を返す
    return (1 / n) * sum(losses)

# 動作確認用に先頭文書の損失を計算
doc0 = docs[0]
# 学習前損失
loss0 = loss_for_document(doc0)
# 対象文書を表示
print('doc:', doc0)
# 損失値を表示
print('loss before training:', round(loss0.data, 4))

12. 勾配チェック(考え方)

自作Autogradの検証法として、有限差分で微分を近似できます。

dLdθL(θ+ϵ)L(θϵ)2ϵ\frac{dL}{d\theta} \approx \frac{L(\theta+\epsilon)-L(\theta-\epsilon)}{2\epsilon}

直感は「パラメータをほんの少しだけ前後に動かし、損失の変化量を測る」です。
これで backward() が返す勾配の“だいたいの正しさ”を確認できます。

厳密一致は不要ですが、オーダーが近いかを確認できます。
(コストが高いので通常は少数パラメータで実施)

# 数値微分に使う微小量
eps_fd = 1e-4
# 検証用の文書
check_doc = docs[1]

# まずAutogradで勾配を計算
loss_fd = loss_for_document(check_doc)
# 逆伝播
loss_fd.backward()
# 先頭パラメータの勾配を取得
autograd_grad = params[0].grad

# 有限差分で同じ勾配を近似
p = params[0]
# 元の値を退避
orig = p.data

# theta + eps の損失
p.data = orig + eps_fd
loss_plus = loss_for_document(check_doc).data

# theta - eps の損失
p.data = orig - eps_fd
loss_minus = loss_for_document(check_doc).data

# パラメータを元に戻す
p.data = orig
# 中心差分で勾配近似
fd_grad = (loss_plus - loss_minus) / (2 * eps_fd)

# Autograd勾配を表示
print('autograd grad :', autograd_grad)
# 有限差分勾配を表示
print('finite diff   :', fd_grad)
# 差分を表示
print('abs diff      :', abs(autograd_grad - fd_grad))

# 後続学習に向けて全勾配をクリア
for p in params:
    # gradを0に戻す
    p.grad = 0

13. Adam最適化

Adamは勾配の1次モーメント(平均)と2次モーメント(分散的量)を使います。

mt=β1mt1+(1β1)gtm_t=\beta_1 m_{t-1} + (1-\beta_1)g_t vt=β2vt1+(1β2)gt2v_t=\beta_2 v_{t-1} + (1-\beta_2)g_t^2 m^t=mt1β1t,v^t=vt1β2t\hat m_t = \frac{m_t}{1-\beta_1^t},\quad \hat v_t = \frac{v_t}{1-\beta_2^t} θt=θt1ηm^tv^t+ϵ\theta_t = \theta_{t-1} - \eta \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}

式を日本語で分解すると次の4段階です。

  1. mtm_t: 勾配の移動平均(進む向きの慣性)
  2. vtv_t: 勾配の二乗平均(振れ幅の大きさ)
  3. hat: 初期バイアス補正(最初の不正確さを補正)
  4. 更新式: 方向は mhatm_hat、歩幅は sqrt(vhat)sqrt(v_hat) で自動調整

数式が読めなくても、「荒い方向へは慎重に、小さい方向へは速く」 進む最適化器、
と理解できれば実装を追うのに十分です。

この実装では線形学習率減衰も入っています。

# 初期学習率
learning_rate = 0.01
# Adamのβ1, β2
beta1, beta2 = 0.85, 0.99
# Adamの数値安定化項
eps_adam = 1e-8

# 1次モーメントバッファ(平均方向)
m = [0.0] * len(params)
# 2次モーメントバッファ(分散スケール)
v = [0.0] * len(params)

# 学習1ステップ分を実行する関数
def train_step(step, num_steps):
    # 今回使う文書を巡回で選ぶ
    doc = docs[step % len(docs)]
    # BOSで挟んだトークン列を作る
    tokens = [BOS] + encode(doc) + [BOS]
    # block_size内に切り詰める
    n = min(block_size, len(tokens) - 1)

    # 層ごとのKVキャッシュ初期化
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    # 位置ごとの損失を保存
    losses = []

    # 各位置で次トークン予測を行う
    for pos_id in range(n):
        # 現在トークン
        token_id = tokens[pos_id]
        # 正解トークン(1つ先)
        target_id = tokens[pos_id + 1]

        # ロジット計算
        logits = gpt(token_id, pos_id, keys, values)
        # 確率に変換
        probs = softmax(logits)
        # 位置ごとのクロスエントロピーを追加
        losses.append(-probs[target_id].log())

    # 系列平均損失
    loss = (1 / n) * sum(losses)
    # 逆伝播で全勾配を計算
    loss.backward()

    # 線形学習率減衰
    lr_t = learning_rate * (1 - step / num_steps)

    # 全パラメータをAdamで更新
    for i, p in enumerate(params):
        # 1次モーメント更新
        m[i] = beta1 * m[i] + (1 - beta1) * p.grad
        # 2次モーメント更新
        v[i] = beta2 * v[i] + (1 - beta2) * (p.grad ** 2)

        # バイアス補正済み1次モーメント
        m_hat = m[i] / (1 - beta1 ** (step + 1))
        # バイアス補正済み2次モーメント
        v_hat = v[i] / (1 - beta2 ** (step + 1))

        # パラメータ更新
        p.data -= lr_t * m_hat / (v_hat ** 0.5 + eps_adam)
        # 次ステップのためgradリセット
        p.grad = 0

    # 表示用に損失値と文書を返す
    return loss.data, doc

14. 学習ループ

初回は軽く回すため num_steps=300 にしています。
精度重視なら1000以上にしてください。

損失が緩やかに下がるかを確認します。

# 学習ステップ数(最初は軽め)
num_steps = 300
# 学習曲線記録用
loss_history = []

# 学習ループ
for step in range(num_steps):
    # 1ステップ実行
    loss_value, doc = train_step(step, num_steps)
    # 履歴に追加
    loss_history.append(loss_value)
    # 先頭と20ステップごとにログ出力
    if step == 0 or (step + 1) % 20 == 0:
        print(f'step {step + 1:4d}/{num_steps:4d} | loss {loss_value:.4f} | doc {doc}')

14.1 学習曲線を簡易表示(ASCII)

外部ライブラリなしで、損失の傾向だけ可視化します。

# 損失推移を1行で可視化する簡易関数
def sparkline(values, width=80):
    # 低い値から高い値までのブロック文字
    chars = '▁▂▃▄▅▆▇█'
    # 空配列なら空文字を返す
    if not values:
        return ''
    # 表示幅を超える場合は間引きサンプリング
    if len(values) > width:
        # 何個おきに取るか
        step = len(values) / width
        # 幅に合わせてサンプリング
        sampled = [values[int(i * step)] for i in range(width)]
    else:
        # そのまま使う
        sampled = values

    # 最小値と最大値
    vmin, vmax = min(sampled), max(sampled)
    # 全値同一なら同じ文字で埋める
    if vmax == vmin:
        return chars[0] * len(sampled)

    # 出力文字配列
    out = []
    # 各値を0..7へ正規化して文字に変換
    for v in sampled:
        # 値を文字インデックスへ変換
        idx = int((v - vmin) / (vmax - vmin) * (len(chars) - 1))
        # 対応文字を追加
        out.append(chars[idx])
    # 連結して返す
    return ''.join(out)

# 初回損失を表示
print('first loss:', round(loss_history[0], 4))
# 最終損失を表示
print('last  loss:', round(loss_history[-1], 4))
# 損失推移を1行で表示
print(sparkline(loss_history))

15. 推論(生成)

学習後は BOS から開始し、1文字ずつサンプリングします。

注意: 学習時(teacher forcing)と推論時(自己生成入力)は分布が違うため、
長文では誤差が蓄積しやすいです。

# 学習済みモデルから1サンプル生成する関数
def generate_one(temperature=0.5, max_new_tokens=block_size):
    # 推論用KVキャッシュを初期化
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    # 先頭はBOSから開始
    token_id = BOS
    # 生成結果を格納
    out = []

    # 最大トークン数まで自己回帰生成
    for pos_id in range(max_new_tokens):
        # 現在トークンから次ロジットを取得
        logits = gpt(token_id, pos_id, keys, values)
        # 温度でスケーリングして確率化
        probs = softmax([l / temperature for l in logits])

        # 確率に従って次トークンをサンプリング
        token_id = random.choices(
            range(vocab_size),
            weights=[p.data for p in probs]
        )[0]

        # BOSが出たら終了
        if token_id == BOS:
            break

        # 通常文字なら結果に追加
        out.append(itos[token_id])

    # 文字列として返す
    return ''.join(out)

# サンプルを複数生成して表示
for i in range(20):
    print(f'sample {i+1:2d}: {generate_one(temperature=0.5)}')

16. 注意重みの観察(可視化)

return_attn=True を使い、各時刻での注意分布を表示します。

t は、"位置tのqueryが過去0..tへどれだけ注意したか"を示します。

# 指定文書のAttention行列(時刻ごとの重み列)を取得
def attention_matrix_for_doc(doc, layer=0, head=0):
    # BOSで挟んだトークン列
    tokens = [BOS] + encode(doc) + [BOS]
    # block_size制限
    n = min(block_size, len(tokens) - 1)

    # 層ごとのKVキャッシュ
    keys, values = [[] for _ in range(n_layer)], [[] for _ in range(n_layer)]
    # 行列(行=時刻, 列=参照先時刻)
    attn_mat = []

    # 各時刻でAttention重みを取得
    for pos_id in range(n):
        # 現在トークン
        token_id = tokens[pos_id]
        # 注意重み付きで前向き計算
        logits, records = gpt(token_id, pos_id, keys, values, return_attn=True)
        # ここではロジット自体は使わない
        _ = logits

        # 指定layer/headの重みだけ抽出
        selected = [w for li, h, w in records if li == layer and h == head][0]
        # 行列へ追加
        attn_mat.append(selected)

    # 対象トークン列と行列を返す
    return tokens[:n], attn_mat

# 表示用にIDを文字へ変換する関数
def token_to_str(tid):
    # BOSは見やすいラベルにする
    return '<BOS>' if tid == BOS else itos[tid]

# 可視化対象文書
probe_doc = docs[2]
# Attention行列を取得
token_ids, attn_mat = attention_matrix_for_doc(probe_doc, layer=0, head=0)
# 表示しやすい文字列トークン列へ変換
tokens_str = [token_to_str(t) for t in token_ids]

# 対象文書を表示
print('probe doc:', probe_doc)
# トークン列を表示
print('tokens   :', tokens_str)
# セクション見出し
print('\nattention matrix (layer0 head0):')
# 各時刻の重みベクトルを表示
for t, row in enumerate(attn_mat):
    # 可読性のため小数2桁に整形
    vals = ' '.join(f'{v:0.2f}' for v in row)
    # 時刻ごとの行を出力
    print(f't={t:2d}: {vals}')

17. 計算量の直感

自己注意は系列長 TT に対し、典型的に O(T2d)O(T^2 d) の計算が必要です。

理由は、各位置の query が全位置の keys と相互作用するためです。

このノートの実装は1トークン逐次処理 + KVキャッシュなので、推論時の1ステップは概ね過去長に比例して増えます。ここでは高速化手法を追うのではなく、どこがボトルネックになりやすいかの感覚を作るところまでに留めます。

18. よくある落とし穴と対策

  1. log(0) 問題
  1. 学習が進まない
  1. 推論が空文字ばかり
  1. 実装の形ミス

19. この最小実装で見えて、まだ省いているもの

この notebook で見えているのは、次トークン予測、Transformer ブロック、逆伝播、Adam による更新、サンプリング生成までです。

一方で省いているのは、主に規模と効率のための工夫です。

つまり、この notebook は pretraining の原理をむき出しで読むための最小模型だと考えてください。

20. 追加課題(理解を実装に変える)

  1. n_layer を 1 -> 2 にして、損失と生成品質を比較する
  2. ReLU を GeLU へ差し替える
  3. block_size を増やし、長い依存の学習を試す
  4. 温度に加えて top-k サンプリングを実装する
  5. ミニバッチ化(今は1サンプル逐次)を実装する
  6. 明示的な因果マスク版 Attention を書いて照合する
  7. 名前データ以外の文字列コーパスへ置き換え、トークン分布の差を見る

ここまで進めると、読むだけではなく、事前学習の骨格を自分で動かして確かめられるようになります。

21. 最終まとめ

このノートの核は次の1文です。

事前学習は、次トークン確率を出す関数を、勾配降下で鍛える過程です。

実際に大事なのは、

です。本稿ではそれを最小コードで示しました。ここで骨格が見えていれば、より大きいモデルや高速実装を読むときにも、何が本質で何が工学上の追加なのかを切り分けやすくなります。