Open In Colab   Open in Kaggle

ボーナスチュートリアル5: スパイキングニューロンのための期待値最大化法

第3週、第3日目: 隠れたダイナミクス

Neuromatch Academyによる

コンテンツ作成者: Yicheng Fei(ジェシー・ライブジーの協力を得て)

コンテンツレビュアー: ジョン・バトラー、マット・クラウス、ミーナクシ・コスラ、スピロス・チャヴリス、マイケル・ワスコム

制作編集者: ガガナ B、スピロス・チャヴリス


重要な注意: 本教材はNMA 2020で開発されたものであり、隠れたダイナミクス教材の基準に従って改訂されていません。


謝辞: 本チュートリアルは、ショーン・エスクオラによって元々作成されたコードに基づいています。


チュートリアルの目的

期待値最大化(EM)アルゴリズムは非常に強力で広く使われている最適化手法であり、HMMよりもはるかに一般的です。通常は隠れマルコフモデルの文脈で教えられるため、ここに含めています。

本日紹介したポアソンスパイキングニューロンのネットワークのHMMを実装し、以下を行います:


セットアップ

# @title Install and import feedback gadget


from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_cn",
            "user_key": "y1x3mpx5",
        },
    ).render()


feedback_prefix = "W3D2_T5_Bonus"
import numpy as np
from scipy import stats
from scipy.optimize import linear_sum_assignment
from collections import namedtuple

import matplotlib.pyplot as plt
from matplotlib import patches

GaussianHMM1D = namedtuple('GaussianHMM1D', ['startprob', 'transmat','means','vars','n_components'])
# @title Figure Settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True

from ipywidgets import widgets, interactive, interact, HBox, Layout,VBox
from IPython.display import HTML
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle")
# @title Plotting functions
def plot_spike_train(X, Y, dt):
  """Plots the spike train for cells across trials and overlay the state.

    Args:
      X: (2d numpy array of binary values): The state sequence in a one-hot
                                            representation. (T, states)
      Y: (3d numpy array of floats):        The spike sequence.
                                            (trials, T, C)
      dt (float):                           Interval for a bin.
  """
  n_trials, T, C = Y.shape
  trial_T = T * dt
  fig = plt.figure(figsize=(.7 * (12.8 + 6.4), .7 * 9.6))

  # plot state sequence
  starts = [0] + list(np.diff(X.nonzero()[1]).nonzero()[0])
  stops = list(np.diff(X.nonzero()[1]).nonzero()[0]) + [T]
  states = [X[i + 1].nonzero()[0][0] for i in starts]
  for a, b, i in zip(starts, stops, states):
    rect = patches.Rectangle((a * dt, 0), (b - a) * dt, n_trials * C,
                              facecolor=plt.get_cmap('tab10').colors[i],
                              alpha=0.15)
    plt.gca().add_patch(rect)

  # plot rasters
  for c in range(C):
    if c > 0:
      plt.plot([0, trial_T], [c * n_trials, c * n_trials],
                color=plt.get_cmap('tab10').colors[0])
    for r in range(n_trials):
      tmp = Y[r, :, c].nonzero()[0]
      if len(tmp) > 0:
        plt.plot(np.stack((tmp, tmp)) * dt, (c * n_trials + r + 0.1, c * n_trials + r + .9), color='k')

  ax = plt.gca()
  plt.yticks(np.arange(0, n_trials * C, n_trials),
             labels=np.arange(C, dtype=int))
  plt.xlabel('time (s)', fontsize=16)
  plt.ylabel('Cell number', fontsize=16)
  plt.show(fig)


def plot_lls(lls):
  """Plots log likelihoods at each epoch.
  Args:
    lls (list of floats) log likelihoods at each epoch.
  """
  epochs = len(lls)
  fig, ax = plt.subplots()
  ax.plot(range(epochs) , lls, linewidth=3)
  span = max(lls) - min(lls)
  ax.set_ylim(min(lls) - span * 0.05, max(lls) + span * 0.05)
  plt.xlabel('iteration')
  plt.ylabel('log likelihood')
  plt.show(fig)


def plot_lls_eclls(plot_epochs, save_vals):
  """Plots log likelihoods at each epoch.
  Args:
    plot_epochs (list of ints):  Which epochs were saved to plot.
    save_vals (lists of floats): Different likelihoods from EM for plotting.
  """
  rows = int(np.ceil(min(len(plot_epochs), len(save_vals)) / 3))
  fig, axes = plt.subplots(rows, 3, figsize=(.7 * 6.4 * 3, .7 * 4.8 * rows))
  axes = axes.flatten()

  minll, maxll = np.inf, -np.inf
  for i, (ax, (bs, lls_for_plot, eclls_for_plot)) in enumerate(zip(axes, save_vals)):
    ax.set_xlim([-1.15, 2.15])
    min_val = np.stack((lls_for_plot, eclls_for_plot)).min()
    max_val = np.stack((lls_for_plot, eclls_for_plot)).max()

    ax.plot([0, 0], [min_val, lls_for_plot[bs == 0].item()], '--b')
    ax.plot([1, 1], [min_val, lls_for_plot[bs == 1].item()], '--b')
    ax.set_xticks([0, 1])
    ax.set_xticklabels([f'$\\theta^{plot_epochs[i]}$',
                        f'$\\theta^{plot_epochs[i] + 1}$'])
    ax.tick_params(axis='y')
    ax.tick_params(axis='x')

    ax.plot(bs, lls_for_plot)
    ax.plot(bs, eclls_for_plot)

    if min_val < minll: minll = min_val
    if max_val > maxll: maxll = max_val

    if i % 3 == 0: ax.set_ylabel('log likelihood')
    if i == 4:
      l = ax.legend(ax.lines[-2:], ['LL', 'ECLL'], framealpha=1)
  plt.show(fig)


def plot_learnt_vs_true(L_true, L, A_true, A, dt):
  """Plot and compare the true and learnt parameters.

  Args:
    L_true (numpy array): True L.
    L (numpy array):      Estimated L.
    A_true (numpy array): True A.
    A (numpy array):      Estimated A.
    dt (float):           Bin length.
  """
  C, K = L.shape
  fig = plt.figure(figsize=(8, 4))
  plt.subplot(121)
  plt.plot([0, L_true.max() * 1.05], [0, L_true.max() * 1.05], '--b')
  for i in range(K):
    for c in range(C):
      plt.plot(L_true[c, i], L[c, i], color='C{}'.format(c),
               marker=['o', '*', 'd'][i])  # this line will fail for K > 3
  ax = plt.gca()
  ax.axis('equal')
  plt.xlabel('True firing rate (Hz)')
  plt.ylabel('Inferred firing rate (Hz)')
  xlim, ylim = ax.get_xlim(), ax.get_ylim()
  for c in range(C):
    plt.plot([-10^6], [-10^6], 'o', color='C{}'.format(c))
  for i in range(K):
    plt.plot([-10^6], [-10^6], marker=['o', '*', 'd'][i], c="black")
  l = plt.legend(ax.lines[-C - K:], [f'cell {c + 1}' for c in range(C)] + [f'state {i + 1}' for i in range(K)])
  ax.set_xlim(xlim), ax.set_ylim(ylim)

  plt.subplot(122)
  ymax = np.max(A_true - np.diag(np.diag(A_true))) / dt * 1.05
  plt.plot([0, ymax], [0, ymax], '--b')
  for j in range(K):
    for i in range(K):
      if i == j:
        continue
      plt.plot(A_true[i, j] / dt, A[i, j] / dt, 'o')
  ax = plt.gca()
  ax.axis('equal')
  plt.xlabel('True transition rate (Hz)')
  plt.ylabel('Inferred transition rate (Hz)')
  l = plt.legend(ax.lines[1:], ['state 1 -> 2',
                                      'state 1 -> 3',
                                      'state 2 -> 1',
                                      'state 2 -> 3',
                                      'state 3 -> 1',
                                      'state 3 -> 2'
                                    ])
  plt.show(fig)
# @title Helper functions
def run_em(epochs, Y, psi, A, L, dt):
  """Run EM for the HMM spiking model.

  Args:
    epochs (int):       Number of epochs of EM to run
    Y (numpy 3d array): Tensor of recordings, has shape (n_trials, T, C)
    psi (numpy vector): Initial probabilities for each state
    A (numpy matrix):   Transition matrix, A[i,j] represents the prob to switch
                        from j to i. Has shape (K,K)
    L (numpy matrix):   Poisson rate parameter for different cells.
                        Has shape (C,K)
    dt (float):         Duration of a time bin

  Returns:
    save_vals (lists of floats): Data for later plotting
    lls (list of flots):         ll Before each EM step
    psi (numpy vector):          Estimated initial probabilities for each state
    A (numpy matrix):            Estimated transition matrix, A[i,j] represents
                                 the prob to switch from j to i. Has shape (K,K)
    L (numpy matrix):            Estimated Poisson rate parameter for different
                                 cells. Has shape (C,K)
  """
  save_vals = []
  lls = []
  for e in range(epochs):

    # Run E-step
    ll, gamma, xi = e_step(Y, psi, A, L, dt)
    lls.append(ll)  # log the data log likelihood for current cycle

    if e % print_every == 0: print(f'epoch: {e:3d}, ll = {ll}')  # log progress
    # Run M-step
    psi_new, A_new, L_new = m_step(gamma, xi, dt)

    """Booking keeping for later plotting
    Calculate the difference of parameters for later
    interpolation/extrapolation
    """
    dp, dA, dL = psi_new - psi, A_new - A, L_new - L
    # Calculate LLs and ECLLs for later plotting
    if e in plot_epochs:
      b_min = -min([np.min(psi[dp > 0] / dp[dp > 0]),
                    np.min(A[dA > 0] / dA[dA > 0]),
                    np.min(L[dL > 0] / dL[dL > 0])])
      b_max = -max([np.max(psi[dp < 0] / dp[dp < 0]),
                    np.max(A[dA < 0] / dA[dA < 0]),
                    np.max(L[dL < 0] / dL[dL < 0])])
      b_min = np.max([.99 * b_min, b_lims[0]])
      b_max = np.min([.99 * b_max, b_lims[1]])
      bs = np.linspace(b_min, b_max, num_plot_vals)
      bs = sorted(list(set(np.hstack((bs, [0, 1])))))
      bs = np.array(bs)
      lls_for_plot = []
      eclls_for_plot = []
      for i, b in enumerate(bs):
        ll = e_step(Y, psi + b * dp, A + b * dA, L + b * dL, dt)[0]
        lls_for_plot.append(ll)
        rate = (L + b * dL) * dt
        ecll = ((gamma[:, 0] @ np.log(psi + b * dp) +
                  (xi * np.log(A + b * dA)).sum(axis=(-1, -2, -3)) +
                  (gamma * stats.poisson(rate).logpmf(Y[..., np.newaxis]).sum(-2)
                  ).sum(axis=(-1, -2))).mean() / T / dt)
        eclls_for_plot.append(ecll)
        if b == 0:
          diff_ll = ll - ecll
      lls_for_plot = np.array(lls_for_plot)
      eclls_for_plot = np.array(eclls_for_plot) + diff_ll
      save_vals.append((bs, lls_for_plot, eclls_for_plot))
    # return new parameter
    psi, A, L = psi_new, A_new, L_new

  ll = e_step(Y, psi, A, L, dt)[0]
  lls.append(ll)
  print(f'epoch: {epochs:3d}, ll = {ll}')
  return save_vals, lls, psi, A, L

セクション0: はじめに

#@title Video 1: Introduction
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="ceQXN0OUaFo", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_Introduction_Video")

セクション1: ポアソンスパイキングニューロンネットワークのHMM

#@title Video 2: HMM for Poisson spiking neurons case study
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="Wb8mf5chmyI", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_HMM_for_Poisson_spiking_neurons_Video")

ノイズの多い神経または行動の測定値が与えられた場合、神経科学者としては、時間とともに変化する観測されていない潜在変数を推定したいことがよくあります。視床中継ニューロンは2つの異なるモードで発火します:スパイクが1つずつ発生するトニックモードと、複数の活動電位が連続して発生する「バーストモード」です。これらのモードは、感覚受容器から皮質への情報伝達を異なる方法で符号化していると考えられています。異なる分子機構であるT型カルシウムチャネルがニューロンのモードを切り替えますが、生きたサルの脳内でこれを測定するのは非常に困難です。しかし、統計的アプローチにより、行動中のサルで測定可能なスパイク活動のみからこれらのカルシウムチャネルの隠れた状態を復元することが可能です。

ここでは、その問題の簡略化バージョンに取り組みます。

イントロ講義で述べた定式化を考えましょう。
CC個のニューロンのネットワークがKK状態間を切り替えます。ニューロンccは状態iiで発火率λic\lambda_i^cを持ちます。状態間の遷移はK×KK\times Kの遷移行列AijA_{ij}と、時刻t=1t=1での長さKKの初期確率ベクトルψ\psiで表されます。

tt番目の時間ビンにおける細胞ccのスパイク数をytcy_t^cとします。


以下の演習(1と2)およびチュートリアルでは、

モデルの定義とデータ生成

まず、隠れマルコフ連鎖からランダムな状態列を生成し、n_frozen_trials個の異なる試行で各細胞のスパイク列を生成します。ここでは、すべての細胞が先ほど生成した同じ基底状態列を使うと仮定します。

提案

  1. 以下の「モデルとシミュレーションパラメータ」および「真のモデルの初期化」の2つのセクションを実行して、真のモデルとパラメータを定義してください。パラメータを確認し、将来不明な変数があればこれらのセルに戻って確認してください。

  2. 与えられた状態列をすべての細胞のすべての時刻に対応するスパイク率に変換するコードを実行し、提供されたコードでスパイク列を可視化してください。

モデルとシミュレーションパラメータ

# model and data parameters
C = 5  # number of cells
K = 3  # number of states
dt = 0.002  # seconds
trial_T = 2.0  # seconds
n_frozen_trials = 20  # used to plot multiple trials with the same state sequence
n_trials = 300  # number of trials (each has it's own state sequence)

# for random data
max_firing_rate = 50  # Hz
max_transition_rate = 3  # Hz

# needed to plot LL and ECLL for every M-step
# **This substantially slows things down!!**
num_plot_vals = 10  # resolution of the plot (this is the expensive part)
b_lims = (-1, 2)  # lower limit on graph (b = 0 is start-of-M-step LL; b = 1 is end-of-M-step LL)
plot_epochs = list(range(9))  # list of epochs to plot

真のモデルの初期化

np.random.seed(101)
T = round(trial_T / dt)
ts = np.arange(T)

# initial state distribution
psi = np.arange(1, K + 1)
psi = psi / psi.sum()

# off-diagonal transition rates sampled uniformly
A = np.random.rand(K, K) * max_transition_rate * dt
A = (1. - np.eye(K)) * A
A = A + np.diag(1 - A.sum(1))

# hand-crafted firing rates make good plots
L = np.array([
    [.02, .8, .37],
    [1., .7, .1],
    [.92, .07, .5],
    [.25, .42, .75],
    [.15, .2, .85]
]) * max_firing_rate  # (C,K)

# Save true parameters for comparison later
psi_true = psi
A_true = A
L_true = L

凍結された状態列でデータ生成とプロット

状態列[0,1,1,3,2,...]が与えられた場合、まず各状態を「ワンホット」符号化に変換します。例えば、状態数が5の場合、状態0のワンホット符号は[1,0,0,0,0]で、状態3の符号は[0,0,0,1,0]です。長さTの列があるとすると、この列のワンホット符号化Xfは形状(T,K)を持ちます。

np.random.seed(101)
# sample n_frozen_trials state sequences
Xf = np.zeros(T, dtype=int)
Xf[0] = (psi.cumsum() > np.random.rand()).argmax()
for t in range(1, T):
  Xf[t] = (A[Xf[t - 1],:].cumsum() > np.random.rand()).argmax()

# switch to one-hot encoding of the state
Xf = np.eye(K, dtype=int)[Xf]  # (T,K)

# get the Y values
Rates = np.squeeze(L @ Xf[..., None]) * dt  # (T,C)

Rates = np.tile(Rates, [n_frozen_trials, 1, 1]) # (n_trials, T, C)
Yf = stats.poisson(Rates).rvs()

plot_spike_train(Xf, Yf, dt)

EM学習のためのデータ生成

前回のデータセットは可視化のために同じ状態系列で生成しました。今回は、各試行ごとにランダムに生成された系列を持つ観測値をn_trials回分生成しましょう。

np.random.seed(101)
# sample n_trials state sequences
X = np.zeros((n_trials, T), dtype=int)
X[:, 0] = (psi_true.cumsum(0)[:, None] > np.random.rand(n_trials)).argmax(0)
for t in range(1, T):
  X[:, t] = (A_true[X[:, t - 1], :].T.cumsum(0) > np.random.rand(n_trials)).argmax(0)

# switch to one-hot encoding of the state
one_hot = np.eye(K)[np.array(X).reshape(-1)]
X = one_hot.reshape(list(X.shape) + [K])

# get the Y values
Y = stats.poisson(np.squeeze(L_true @ X[..., None]) * dt).rvs()  # (n_trials, T, C)
print("Y has shape: (n_trial={},T={},C={})".format(*Y.shape))

セクション2: HMMのためのEMアルゴリズム

#@title Video 3: EM Tutorial
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="umU4wUWlKvg", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_EM_tutorial_Video")

データの尤度を最大化するパラメータの最適値を見つけることは、すべての潜在変数x1:Tx_{1:T}を積分する必要があるため、実質的に不可能です。必要な時間はTTに対して指数関数的に増加します。そこで代替手法として、期待値最大化(EM)アルゴリズムを用います。これはEステップとMステップを交互に繰り返し、各EMサイクル後にデータの尤度が減少しない(通常は増加する)ことが保証されています。

このセクションでは、HMMのEMアルゴリズムを簡単に復習し、以下を示します。

Eステップ: 順方向-逆方向アルゴリズム

順方向パスでは、xtx_tと現在および過去のデータY1:tY_{1:t}の結合確率、すなわち順方向確率ai(t):=p(xt=i,Y1:t)a_i(t):=p(x_t=i,Y_{1:t})を以下のように再帰的に計算します。

ai(t)=p(ytxt=i)jAjiaj(t1)a_i(t) = p(y_t|x_t=i)\sum_j A_{ji} a_j(t-1)

導入部とは異なり、ここでのAjiA_{ji}状態jjから状態iiへの遷移確率を意味します。

逆方向パスでは、逆方向確率bi(t):=pθ(Yt+1:Txt=i)b_i(t):=p_{\theta}(Y_{t+1:T}|x_t=i)を計算します。これは現在の状態xtx_tが与えられたときに将来のすべてのデータ点を観測する尤度です。bi(t)b_i(t)の再帰式は以下の通りです。

bi(t)=jpθ(yt+1xt+1=j)bj(t+1)Aijb_i(t) = \sum_j p_{\theta}(y_{t+1}|x_{t+1}=j)b_j(t+1)A_{ij}

過去と未来の情報を組み合わせると、単一およびペアの周辺分布は以下のように与えられます。

γi(t):=pθ(xt=iY1:T)=ai(t)bi(t)pθ(Y1:T)\gamma_{i}(t):=p_{\theta}\left(x_{t}=i | Y_{1: T}\right)=\frac{a_{i}(t) b_{i}(t)}{p_{\theta}\left(Y_{1: T}\right)} ξij(t)=pθ(xt=i,xt+1=jY1:T)=bj(t+1)pθ(yt+1xt+1=j)Aijai(t)pθ(Y1:T)\xi_{i j}(t) = p_{\theta}(x_t=i,x_{t+1}=j|Y_{1:T}) =\frac{b_{j}(t+1)p_{\theta}\left(y_{t+1} | x_{t+1}=j\right) A_{i j} a_{i}(t)}{p_{\theta}\left(Y_{1: T}\right)}

ここでpθ(Y1:T)=iai(T)p_{\theta}(Y_{1:T})=\sum_i a_i(T)です。

Mステップ

HMMのMステップは閉形式解があります。まず、新しい遷移行列は以下のように与えられます。

Aij=t=1T1ξij(t)t=1T1γi(t)A_{ij} =\frac{\sum_{t=1}^{T-1} \xi_{i j}(t)}{\sum_{t=1}^{T-1} \gamma_{i}(t)}

これは期待される経験的遷移確率です。
新しい初期確率と放出モデルのパラメータも、単一およびペアの周辺分布に基づく経験的値として与えられます:

\begin{align}
ψi\psi_i &= 1Ntrialsγi(1)\frac{1}{N}\sum_{trials}\gamma_i(1) \
λic\lambda_{i}^{c} &= \frac{\sum_{t} \gamma_{i}(t) y_{t}^{c}}{\sum_{t} \gamma_{i}(t) d t} \end{align}

Eステップ: 順方向および逆方向アルゴリズム

(任意)

このセクションでは、順方向-逆方向アルゴリズムのコードを読み、すべての試行を一度に計算することでnumpyで効率的に計算を実装する方法を理解します。

順方向および逆方向の再帰式をより簡潔な形で書き直しましょう:

\begin{eqnarray}
aita_i^t &=& jAjiojtajt1\sum_j A_{ji}o_j^t a_j^{t-1}\
bitb^t_i &=& \sum_j A_{ij} o_j^{t+1}b_j^{t+1} \text{, ここで } o_j^{t}=p(y_{t}|x_{t}=j) \end{eqnarray}

逆方向再帰を例にとりましょう。実際には、試行は互いに独立なので全試行をまとめて扱います。試行インデックスllを再帰式に加えると、逆方向再帰は次のようになります。

blit=jAijoljt+1bljt+1b^t_{li} = \sum_j A_{ij} o_{lj}^{t+1}b_{lj}^{t+1}

手元にあるのは:

ここでNは試行数を表します。

これら3つの配列のインデックスのサイズと意味は一致しません。AAの1次元目はiioobbllなので、単純に掛け合わせることはできません。しかし、olt+1o^{t+1}_{l\cdot}blt+1b^{t+1}_{l\cdot}を1行の行列として見なすことで、逆方向の式を次のように書き換えられます。

blit=jAijol1jt+1bl1jt+1b^t_{li} = \sum_j A_{ij} o_{l1j}^{t+1}b_{l1j}^{t+1}

これで3つの配列を要素ごとに掛けて最後の次元で合計できます。

numpyでは、配列に新しい次元を挿入したい位置にNoneを使ってインデックス指定することでこれを実現できます。例えば、サイズ(N,T,K)bについて、b[:,t,:]は形状(N,K)b[:,t,None,:](N,1,K)b[:,t,:,None](N,K,1)となります。

したがって、逆方向再帰の計算は以下のように実装できます。

b[:,t,:] = (A * o[:,t+1,None,:] * b[:,t+1,None,:]).sum(-1)

上記のトリックに加え、この演習では数値安定性のため対数スケールで作業します。

提案: 順方向および逆方向の再帰のコードを確認してみてください。

def e_step(Y, psi, A, L, dt):
  """Calculate the E-step for the HMM spiking model.

  Args:
    Y (numpy 3d array): tensor of recordings, has shape (n_trials, T, C)
    psi (numpy vector): initial probabilities for each state
    A (numpy matrix):   transition matrix, A[i,j] represents the prob to
                        switch from i to j. Has shape (K,K)
    L (numpy matrix):   Poisson rate parameter for different cells.
                        Has shape (C,K)
    dt (float):         Bin length

  Returns:
    ll (float):             data log likelihood
    gamma (numpy 3d array): singleton marginal distribution.
                            Has shape (n_trials, T, K)
    xi (numpy 4d array):    pairwise marginal distribution for adjacent
                            nodes . Has shape (n_trials, T-1, K, K)
  """
  n_trials = Y.shape[0]
  T = Y.shape[1]
  K = psi.size
  log_a = np.zeros((n_trials, T, K))
  log_b = np.zeros((n_trials, T, K))

  log_A = np.log(A)
  log_obs = stats.poisson(L * dt).logpmf(Y[..., None]).sum(-2)  # n_trials, T, K

  # forward pass
  log_a[:, 0] = log_obs[:, 0] + np.log(psi)
  for t in range(1, T):
    tmp = log_A + log_a[:, t - 1, : ,None]  # (n_trials, K,K)
    maxtmp = tmp.max(-2)  # (n_trials,K)
    log_a[:, t] = (log_obs[:, t] + maxtmp +
                    np.log(np.exp(tmp - maxtmp[:, None]).sum(-2)))

  # backward pass
  for t in range(T - 2, -1, -1):
    tmp = log_A + log_b[:, t + 1, None] + log_obs[:, t + 1, None]
    maxtmp = tmp.max(-1)
    log_b[:, t] = maxtmp + np.log(np.exp(tmp - maxtmp[..., None]).sum(-1))

  # data log likelihood
  maxtmp = log_a[:, -1].max(-1)
  ll = np.log(np.exp(log_a[:, -1] - maxtmp[:, None]).sum(-1)) + maxtmp

  # singleton and pairwise marginal distributions
  gamma = np.exp(log_a + log_b - ll[:, None, None])
  xi = np.exp(log_a[:, :-1, :, None] + (log_obs + log_b)[:, 1:, None] +
              log_A - ll[:, None, None, None])

  return ll.mean() / T / dt, gamma, xi
#@title Video 4: Implement the M-step
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="H4GGTg_9BaE", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_the_M_step_Video")

コーディング演習1: Mステップの実装

この演習では、前述の閉形式解を用いてこのHMMのMステップを完成させます。

提案

  1. 単一周辺分布の経験的カウントとして新しい初期確率を計算する
ψi=1Ntrialsγi(1)\psi_i = \frac{1}{N}\sum_{trials}\gamma_i(1)
  1. 試行の次元が追加されていることを忘れず、すべての試行で平均をとる

参考:

新しい遷移行列は周辺分布からの遷移イベントの経験的カウントとして計算されます。

Aij=t=1T1ξij(t)t=1T1γi(t)A_{ij} =\frac{\sum_{t=1}^{T-1} \xi_{i j}(t)}{\sum_{t=1}^{T-1} \gamma_{i}(t)}

各セルおよび各状態の新しい発火率は以下の通りです。

λic=tγi(t)ytctγi(t)dt\lambda_{i}^{c}=\frac{\sum_{t} \gamma_{i}(t) y_{t}^{c}}{\sum_{t} \gamma_{i}(t) d t}
def m_step(gamma, xi, dt):
  """Calculate the M-step updates for the HMM spiking model.

  Args:
    gamma (numpy 3d array): singleton marginal distribution.
                            Has shape (n_trials, T, K)
    xi (numpy 3d array): Tensor of recordings, has shape (n_trials, T, C)
    dt (float):         Duration of a time bin

  Returns:
    psi_new (numpy vector): Updated initial probabilities for each state
    A_new (numpy matrix):   Updated transition matrix, A[i,j] represents the
                            prob. to switch from j to i. Has shape (K,K)
    L_new (numpy matrix):   Updated Poisson rate parameter for different
                            cells. Has shape (C,K)
  """
  raise NotImplementedError("`m_step` need to be implemented")
  ############################################################################
  # Insert your code here to:
  #    Calculate the new prior probabilities in each state at time 0
  #    Hint: Take the first time step and average over all trials
  ###########################################################################
  psi_new = ...
  # Make sure the probabilities are normalized
  psi_new /= psi_new.sum()

  # Calculate new transition matrix
  A_new = xi.sum(axis=(0, 1)) / gamma[:, :-1].sum(axis=(0, 1))[:, np.newaxis]
  # Calculate new firing rates
  L_new = (np.swapaxes(Y, -1, -2) @ gamma).sum(axis=0) / gamma.sum(axis=(0, 1)) / dt
  return psi_new, A_new, L_new

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_M_step_Exercise")
#@title Video 5: Running and plotting EM
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="6UTsXxE3hG0", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_Running_and_plotting_EM_Video")

EMの実行

パラメータの初期化

np.random.seed(101)
# number of EM steps
epochs = 9
print_every = 1

# initial state distribution
psi = np.arange(1, K + 1)
psi = psi / psi.sum()

# off-diagonal transition rates sampled uniformly
A = np.ones((K, K)) * max_transition_rate * dt / 2
A = (1 - np.eye(K)) * A
A = A + np.diag(1 - A.sum(1))

# firing rates sampled uniformly
L = np.random.rand(C, K) * max_firing_rate
# LL for true vs. initial parameters
print(f'LL for true 𝜃:    {e_step(Y, psi_true, A_true, L_true, dt)[0]}')
print(f'LL for initial 𝜃: {e_step(Y, psi, A, L, dt)[0]}\n')

# Run EM
save_vals, lls, psi, A, L = run_em(epochs, Y, psi, A, L, dt)
# EM doesn't guarantee the order of learnt latent states are the same as that of true model
# so we need to sort learnt parameters

# Compare all true and estimated latents across cells
cost_mat = np.sum((L_true[..., np.newaxis] - L[:, np.newaxis])**2, axis=0)
true_ind, est_ind = linear_sum_assignment(cost_mat)

psi = psi[est_ind]
A = A[est_ind]
A = A[:, est_ind]
L = L[:, est_ind]

トレーニング過程と学習済みモデルのプロット

EMの進行状況のプロット

これで以下ができます。

# Plot the log likelihood after each epoch of EM
with plt.xkcd():
  plot_lls(lls)
# For each saved epoch, plot the log likelihood and expected complete log likelihood
# for the initial and final parameter values
with plt.xkcd():
  plot_lls_eclls(plot_epochs, save_vals)

学習済みパラメータと真のパラメータの比較プロット

ここでは、(ソートした)学習済みパラメータと真のパラメータをプロットし、すべてのパラメータを正しく復元できたかを確認します。

# Compare true and learnt parameters
with plt.xkcd():
  plot_learnt_vs_true(L_true, L, A_true, A, dt)