機械学習のためのSQL

機械学習で使う SQL は、集計の便利な道具ではなく、学習データの意味そのものを決める装置です。どの時点の情報を特徴量とし、どの未来をラベルとするかが曖昧なままでは、後段のモデルは正しくても全体は壊れます。

SQL の質は、クエリが実行できるかではなく、本番の予測時点を再現できているかで測られます。

予測時点がすべてを決める

学習テーブルを作るときに最初に定めるべきなのは、どの時点で何を知っていて、何をまだ知らないかです。予測時点より後の情報が特徴量へ混ざった瞬間、その学習は未来を見たものになります。

ユーザー属性、行動ログ、購入履歴を結び、特徴量窓とラベル窓を切り分けて 1 行の学習サンプルを作ります。同じ集計でも、時点境界の置き方ひとつで意味が変わります。

基本語彙

最重要なのは、特徴量窓とラベル窓を混ぜないことです。

SQL の正しさは時点境界で決まる

同じ SUMCOUNT でも、snapshot より前だけを集めたのか、後ろまで含めたのかで学習データの意味は一変します。SQL の正しさは構文ではなく時間整合性にあります。

典型的な失敗

もっとも危険なのは未来イベントの混入です。加えて、分割方法をその場しのぎで決めると、再現性のない実験になりやすくなります。

扱う範囲

焦点は、最小の SQL 例でリークと point-in-time 設計を理解することに置きます。本番 ETL 全体ではなく、学習テーブル作成の原理を押さえます。

学習テーブルという 1 行の約束

最終的に欲しいのは、1 行が 1 つの予測時点を表す表です。その 1 行に、予測時点までに利用可能だった情報だけを詰める。この約束が守られると、後続の分割や再学習も一貫します。

土台テーブルを作る

users, events, orders の 3 テーブルを用意します。属性、行動、結果が分かれている構成は実務でも自然で、JOIN の意味を追いやすくします。

import sqlite3
from datetime import datetime, timedelta
from random import Random
from textwrap import dedent

どのテーブルが何の情報を持つかが明確になると、あとで JOIN や集計を書いたときに、どの列がどの時間軸に属しているかを見失いにくくなります。

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
cur = conn.cursor()

cur.executescript(
    dedent(
        """
        CREATE TABLE users (
            user_id INTEGER PRIMARY KEY,
            signup_date TEXT NOT NULL,
            plan TEXT NOT NULL,
            country TEXT NOT NULL
        );

        CREATE TABLE events (
            event_id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            event_time TEXT NOT NULL,
            event_type TEXT NOT NULL,
            session_seconds INTEGER NOT NULL,
            clicks INTEGER NOT NULL,
            FOREIGN KEY (user_id) REFERENCES users(user_id)
        );

        CREATE TABLE orders (
            order_id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER NOT NULL,
            order_time TEXT NOT NULL,
            amount REAL NOT NULL,
            FOREIGN KEY (user_id) REFERENCES users(user_id)
        );
        """
    )
)
conn.commit()

補助関数

複数のクエリを試す場面では、実行と表示の補助関数があると結果の確認に集中できます。主役は SQL 本体であり、補助関数は表の形を素早く確かめるための道具です。

def q(sql: str, params=()):
    return conn.execute(dedent(sql), params).fetchall()


def show(rows, limit=8):
    rows = list(rows)
    if not rows:
        print("(no rows)")
        return
    cols = rows[0].keys()
    print(" | ".join(cols))
    print("-" * 100)
    for r in rows[:limit]:
        print(" | ".join(str(r[c]) for c in cols))
    if len(rows) > limit:
        print(f"... ({len(rows)} rows total)")

クエリを書いたらすぐ結果を見る。その習慣があると、件数や列名の違和感から早い段階でミスに気づけます。

ユーザー生成、イベント生成、注文生成の順に疑似データを入れると、あとで集計した数値がどこから来たのかを追跡しやすくなります。

rng = Random(42)
start = datetime(2024, 1, 1)
end = datetime(2024, 6, 30)

plans = ["free", "pro", "team"]
plan_weights = [0.62, 0.30, 0.08]
countries = ["JP", "US", "IN", "DE", "FR"]

n_users = 320
for user_id in range(1, n_users + 1):
    signup = start + timedelta(days=rng.randint(0, 80))
    plan = rng.choices(plans, weights=plan_weights, k=1)[0]
    country = rng.choice(countries)
    cur.execute(
        "INSERT INTO users (user_id, signup_date, plan, country) VALUES (?, ?, ?, ?)",
        (user_id, signup.date().isoformat(), plan, country),
    )

conn.commit()

学習前処理としての SQL 基本操作

WHERE, ORDER BY, GROUP BY, JOIN, CASE WHEN は文法項目ではなく、特徴量を作るための部品です。行を絞り、並べ、要約し、別表の情報を結びつけて、学習に渡せる 1 行へ寄せていきます。

for row in q("SELECT user_id, plan, signup_date FROM users"):
    user_id = row["user_id"]
    plan = row["plan"]
    signup_dt = datetime.fromisoformat(row["signup_date"])

    base_events = 20 if plan == "free" else 34 if plan == "pro" else 46
    n_events = max(8, int(rng.gauss(base_events, 6)))

    for _ in range(n_events):
        day_offset = rng.randint(0, 178)
        event_dt = signup_dt + timedelta(days=day_offset, hours=rng.randint(0, 23), minutes=rng.randint(0, 59))
        if event_dt > end:
            continue

        if plan == "free":
            clicks = max(0, int(rng.gauss(2.2, 1.4)))
            session_seconds = max(20, int(rng.gauss(130, 70)))
        elif plan == "pro":
            clicks = max(0, int(rng.gauss(4.8, 2.0)))
            session_seconds = max(30, int(rng.gauss(220, 90)))
        else:
            clicks = max(0, int(rng.gauss(6.5, 2.4)))
            session_seconds = max(40, int(rng.gauss(290, 110)))

        event_type = rng.choices(
            ["page_view", "search", "add_to_cart"],
            weights=[0.62, 0.25, 0.13],
            k=1,
        )[0]

        cur.execute(
            """
            INSERT INTO events (user_id, event_time, event_type, session_seconds, clicks)
            VALUES (?, ?, ?, ?, ?)
            """,
            (user_id, event_dt.isoformat(sep=" "), event_type, session_seconds, clicks),
        )

conn.commit()

各クエリでは、結果の形だけでなく、その集計が snapshot の前なのか後なのかを必ず確認する必要があります。

activity = q(
    """
    SELECT
        u.user_id,
        u.plan,
        u.signup_date,
        COALESCE(COUNT(e.event_id), 0) AS event_count,
        COALESCE(SUM(e.clicks), 0) AS total_clicks
    FROM users u
    LEFT JOIN events e ON u.user_id = e.user_id
    GROUP BY u.user_id, u.plan, u.signup_date
    """
)

for r in activity:
    user_id = r["user_id"]
    plan = r["plan"]
    signup_dt = datetime.fromisoformat(r["signup_date"])
    total_clicks = r["total_clicks"]
    event_count = r["event_count"]

    base_prob = 0.02
    if plan == "pro":
        base_prob += 0.08
    if plan == "team":
        base_prob += 0.12
    base_prob += min(0.35, total_clicks / 380.0)
    base_prob += min(0.20, event_count / 420.0)

    n_orders = 0
    for _ in range(3):
        if rng.random() < base_prob:
            n_orders += 1

    min_order_dt = signup_dt + timedelta(days=1)
    max_order_dt = min(end, signup_dt + timedelta(days=180))
    if min_order_dt > max_order_dt:
        continue

    total_seconds = int((max_order_dt - min_order_dt).total_seconds())

    for _ in range(n_orders):
        offset_sec = rng.randint(0, total_seconds)
        order_dt = min_order_dt + timedelta(seconds=offset_sec)

        mean_amount = 38 if plan == "free" else 74 if plan == "pro" else 130
        amount = max(8, round(rng.gauss(mean_amount, mean_amount * 0.35), 2))

        cur.execute(
            "INSERT INTO orders (user_id, order_time, amount) VALUES (?, ?, ?)",
            (user_id, order_dt.isoformat(sep=" "), amount),
        )

conn.commit()

特徴量候補は、見た目には単なる集計でも、本質的には「予測時点で利用可能かどうか」の判定です。そこを外すと、きれいな SQL でも学習には使えません。

table_counts = q(
    """
    SELECT 'users' AS table_name, COUNT(*) AS n FROM users
    UNION ALL
    SELECT 'events' AS table_name, COUNT(*) AS n FROM events
    UNION ALL
    SELECT 'orders' AS table_name, COUNT(*) AS n FROM orders
    """
)
show(table_counts)

WHEREORDER BY は、行動ログを時間順に読む最初の道具です。個々のイベント列から時間軸が立ち上がると、後の集計が何を圧縮しているのかが見えやすくなります。

rows = q(
    """
    SELECT user_id, event_time, event_type, session_seconds, clicks
    FROM events
    WHERE user_id = ?
    ORDER BY event_time
    LIMIT 10
    """,
    (12,),
)
show(rows)

GROUP BY によって、ばらばらのイベントをユーザー単位の回数や合計へ要約できます。学習モデルは、生ログそのものより、こうした要約量の方を直接扱いやすいことが多くあります。

agg = q(
    """
    SELECT
        user_id,
        COUNT(*) AS n_events,
        SUM(clicks) AS total_clicks,
        AVG(session_seconds) AS avg_session_seconds
    FROM events
    GROUP BY user_id
    HAVING COUNT(*) >= 18
    ORDER BY total_clicks DESC
    LIMIT 10
    """
)
show(agg)

JOINCASE WHEN を使うと、属性と行動を組み合わせた特徴量が作れます。複数テーブルの情報を 1 行へ寄せる作業が、学習テーブル設計の中心です。

asof_ts = '2024-05-01 00:00:00'

joined = q(
    """
    WITH spend_90 AS (
      SELECT
          user_id,
          SUM(amount) AS spend_90d
      FROM orders
      WHERE order_time >= datetime(:asof_ts, '-90 day')
        AND order_time < :asof_ts
      GROUP BY user_id
    )
    SELECT
        u.user_id,
        u.plan,
        CASE WHEN u.plan = 'free' THEN 0 ELSE 1 END AS paid_flag,
        COALESCE(s.spend_90d, 0) AS spend_90d
    FROM users u
    LEFT JOIN spend_90 s ON u.user_id = s.user_id
    ORDER BY spend_90d DESC
    LIMIT 12
    """,
    {"asof_ts": asof_ts},
)
show(joined)

snapshot を予測時点、その前 30 日を特徴量窓、その後 30 日をラベル窓と定めると、SQL は単なる集計ではなく point-in-time のデータ設計になります。

snapshot = '2024-05-01 00:00:00'
horizon_days = 30

feature_sql = """
WITH event_30d AS (
    SELECT
        user_id,
        COUNT(*) AS events_30d,
        SUM(clicks) AS clicks_30d,
        AVG(session_seconds) AS avg_session_30d,
        MAX(event_time) AS last_event_time
    FROM events
    WHERE event_time < :snapshot
      AND event_time >= datetime(:snapshot, '-30 day')
    GROUP BY user_id
),
order_90d AS (
    SELECT
        user_id,
        SUM(amount) AS spend_90d
    FROM orders
    WHERE order_time < :snapshot
      AND order_time >= datetime(:snapshot, '-90 day')
    GROUP BY user_id
),
label_window AS (
    SELECT
        user_id,
        CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS purchased_30d
    FROM orders
    WHERE order_time >= :snapshot
      AND order_time < datetime(:snapshot, '+' || :horizon_days || ' day')
    GROUP BY user_id
)
SELECT
    u.user_id,
    u.country,
    CASE WHEN u.plan = 'free' THEN 0 ELSE 1 END AS paid_flag,
    COALESCE(e.events_30d, 0) AS events_30d,
    COALESCE(e.clicks_30d, 0) AS clicks_30d,
    COALESCE(e.avg_session_30d, 0) AS avg_session_30d,
    -- イベントがないユーザーは signup_date を最終行動日として扱う
    CAST((julianday(:snapshot) - julianday(COALESCE(e.last_event_time, u.signup_date))) AS INTEGER) AS days_since_last_event,
    COALESCE(o.spend_90d, 0) AS spend_90d,
    COALESCE(l.purchased_30d, 0) AS label
FROM users u
LEFT JOIN event_30d e ON u.user_id = e.user_id
LEFT JOIN order_90d o ON u.user_id = o.user_id
LEFT JOIN label_window l ON u.user_id = l.user_id
ORDER BY u.user_id
"""

リークしない表とリークした表

未来を含まない clean な特徴量表と、未来を少しだけ見てしまう leaky な特徴量表を比べると、条件 1 つで学習の意味がどれほど壊れるかが分かります。

feature_rows = q(feature_sql, {"snapshot": snapshot, "horizon_days": horizon_days})
show(feature_rows, limit=12)
print('rows:', len(feature_rows))

境界条件がわずかに違うだけでも、ラベル窓の情報が特徴量へ混ざれば、モデルはまだ起きていない購買を先回りで知ってしまいます。

label_dist_sql = f"""
WITH base AS (
{feature_sql}
)
SELECT
    label,
    COUNT(*) AS n_users,
    ROUND(AVG(events_30d), 2) AS avg_events_30d,
    ROUND(AVG(spend_90d), 2) AS avg_spend_90d
FROM base
GROUP BY label
ORDER BY label
"""

label_dist = q(label_dist_sql, {"snapshot": snapshot, "horizon_days": horizon_days})
show(label_dist)

差の原因を時間境界に限定するため、窓幅は同じ 30 日に固定します。何がリークを生んだのかを 1 つずつ切り分けて読むことが大切です。

clean_stats_sql = f"""
WITH base AS (
{feature_sql}
)
SELECT label, ROUND(AVG(events_30d), 2) AS avg_events
FROM base
GROUP BY label
ORDER BY label
"""

leaky_stats_sql = """
WITH event_30d_leaky AS (
    SELECT
        user_id,
        COUNT(*) AS events_30d
    FROM events
    WHERE event_time >= :snapshot
      -- NG: snapshot以降(未来)のイベントを特徴量に含める
      AND event_time < datetime(:snapshot, '+' || :horizon_days || ' day')
    GROUP BY user_id
),
label_window AS (
    SELECT
        user_id,
        CASE WHEN COUNT(*) > 0 THEN 1 ELSE 0 END AS label
    FROM orders
    WHERE order_time >= :snapshot
      AND order_time < datetime(:snapshot, '+' || :horizon_days || ' day')
    GROUP BY user_id
)
SELECT
    COALESCE(l.label, 0) AS label,
    ROUND(AVG(COALESCE(e.events_30d, 0)), 2) AS avg_events
FROM users u
LEFT JOIN event_30d_leaky e ON u.user_id = e.user_id
LEFT JOIN label_window l ON u.user_id = l.user_id
GROUP BY COALESCE(l.label, 0)
ORDER BY label
"""

clean_stats = q(clean_stats_sql, {"snapshot": snapshot, "horizon_days": horizon_days})
leaky_stats = q(leaky_stats_sql, {"snapshot": snapshot, "horizon_days": horizon_days})

print('clean feature mean by label')
show(clean_stats)
print('leaky feature mean by label')
show(leaky_stats)

cleanleaky の差は、SQL の条件が些細に見えても、snapshot をまたいだ瞬間に学習データ全体の意味が変わることを示しています。

fold を作るのは、学習テーブルを再現可能な形で分割できるようにするためです。後段のモデリングだけでなく、前処理の段階から再現性を持たせておく必要があります。

dataset_fold_sql = f"""
WITH base AS (
{feature_sql}
)
SELECT
    user_id,
    paid_flag,
    events_30d,
    clicks_30d,
    avg_session_30d,
    days_since_last_event,
    spend_90d,
    label,
    (user_id % 5) AS fold_id
FROM base
ORDER BY user_id
"""

dataset_with_fold = q(dataset_fold_sql, {"snapshot": snapshot, "horizon_days": horizon_days})
show(dataset_with_fold, limit=12)

SQL は機械学習パイプラインの前処理ではなく、学習品質を支える中核です。ラベル時点と特徴量時点の境界を明示し、未来情報を混ぜないことが、どんな高性能モデルより先に満たされるべき条件です。