ボーナスチュートリアル5: スパイキングニューロンのための期待値最大化法
第3週、第3日目: 隠れたダイナミクス
Neuromatch Academyによる
コンテンツ作成者: Yicheng Fei(ジェシー・ライブジーの協力を得て)
コンテンツレビュアー: ジョン・バトラー、マット・クラウス、ミーナクシ・コスラ、スピロス・チャヴリス、マイケル・ワスコム
制作編集者: ガガナ B、スピロス・チャヴリス
重要な注意: 本教材はNMA 2020で開発されたものであり、隠れたダイナミクス教材の基準に従って改訂されていません。
謝辞: 本チュートリアルは、ショーン・エスクオラによって元々作成されたコードに基づいています。
チュートリアルの目的
期待値最大化(EM)アルゴリズムは非常に強力で広く使われている最適化手法であり、HMMよりもはるかに一般的です。通常は隠れマルコフモデルの文脈で教えられるため、ここに含めています。
本日紹介したポアソンスパイキングニューロンのネットワークのHMMを実装し、以下を行います:
- フォワード・バックワードアルゴリズムの実装
- EステップとMステップの完成
- EMアルゴリズムを用いた例題のパラメータ学習
- EMアルゴリズムがデータの尤度を単調に増加させる直感を得る
セットアップ
# @title Install and import feedback gadget
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_cn",
"user_key": "y1x3mpx5",
},
).render()
feedback_prefix = "W3D2_T5_Bonus"
import numpy as np
from scipy import stats
from scipy.optimize import linear_sum_assignment
from collections import namedtuple
import matplotlib.pyplot as plt
from matplotlib import patches
GaussianHMM1D = namedtuple('GaussianHMM1D', ['startprob', 'transmat','means','vars','n_components'])
# @title Figure Settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True
from ipywidgets import widgets, interactive, interact, HBox, Layout,VBox
from IPython.display import HTML
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle")
# @title Plotting functions
def plot_spike_train(X, Y, dt):
"""Plots the spike train for cells across trials and overlay the state.
Args:
X: (2d numpy array of binary values): The state sequence in a one-hot
representation. (T, states)
Y: (3d numpy array of floats): The spike sequence.
(trials, T, C)
dt (float): Interval for a bin.
"""
n_trials, T, C = Y.shape
trial_T = T * dt
fig = plt.figure(figsize=(.7 * (12.8 + 6.4), .7 * 9.6))
# plot state sequence
starts = [0] + list(np.diff(X.nonzero()[1]).nonzero()[0])
stops = list(np.diff(X.nonzero()[1]).nonzero()[0]) + [T]
states = [X[i + 1].nonzero()[0][0] for i in starts]
for a, b, i in zip(starts, stops, states):
rect = patches.Rectangle((a * dt, 0), (b - a) * dt, n_trials * C,
facecolor=plt.get_cmap('tab10').colors[i],
alpha=0.15)
plt.gca().add_patch(rect)
# plot rasters
for c in range(C):
if c > 0:
plt.plot([0, trial_T], [c * n_trials, c * n_trials],
color=plt.get_cmap('tab10').colors[0])
for r in range(n_trials):
tmp = Y[r, :, c].nonzero()[0]
if len(tmp) > 0:
plt.plot(np.stack((tmp, tmp)) * dt, (c * n_trials + r + 0.1, c * n_trials + r + .9), color='k')
ax = plt.gca()
plt.yticks(np.arange(0, n_trials * C, n_trials),
labels=np.arange(C, dtype=int))
plt.xlabel('time (s)', fontsize=16)
plt.ylabel('Cell number', fontsize=16)
plt.show(fig)
def plot_lls(lls):
"""Plots log likelihoods at each epoch.
Args:
lls (list of floats) log likelihoods at each epoch.
"""
epochs = len(lls)
fig, ax = plt.subplots()
ax.plot(range(epochs) , lls, linewidth=3)
span = max(lls) - min(lls)
ax.set_ylim(min(lls) - span * 0.05, max(lls) + span * 0.05)
plt.xlabel('iteration')
plt.ylabel('log likelihood')
plt.show(fig)
def plot_lls_eclls(plot_epochs, save_vals):
"""Plots log likelihoods at each epoch.
Args:
plot_epochs (list of ints): Which epochs were saved to plot.
save_vals (lists of floats): Different likelihoods from EM for plotting.
"""
rows = int(np.ceil(min(len(plot_epochs), len(save_vals)) / 3))
fig, axes = plt.subplots(rows, 3, figsize=(.7 * 6.4 * 3, .7 * 4.8 * rows))
axes = axes.flatten()
minll, maxll = np.inf, -np.inf
for i, (ax, (bs, lls_for_plot, eclls_for_plot)) in enumerate(zip(axes, save_vals)):
ax.set_xlim([-1.15, 2.15])
min_val = np.stack((lls_for_plot, eclls_for_plot)).min()
max_val = np.stack((lls_for_plot, eclls_for_plot)).max()
ax.plot([0, 0], [min_val, lls_for_plot[bs == 0].item()], '--b')
ax.plot([1, 1], [min_val, lls_for_plot[bs == 1].item()], '--b')
ax.set_xticks([0, 1])
ax.set_xticklabels([f'$\\theta^{plot_epochs[i]}$',
f'$\\theta^{plot_epochs[i] + 1}$'])
ax.tick_params(axis='y')
ax.tick_params(axis='x')
ax.plot(bs, lls_for_plot)
ax.plot(bs, eclls_for_plot)
if min_val < minll: minll = min_val
if max_val > maxll: maxll = max_val
if i % 3 == 0: ax.set_ylabel('log likelihood')
if i == 4:
l = ax.legend(ax.lines[-2:], ['LL', 'ECLL'], framealpha=1)
plt.show(fig)
def plot_learnt_vs_true(L_true, L, A_true, A, dt):
"""Plot and compare the true and learnt parameters.
Args:
L_true (numpy array): True L.
L (numpy array): Estimated L.
A_true (numpy array): True A.
A (numpy array): Estimated A.
dt (float): Bin length.
"""
C, K = L.shape
fig = plt.figure(figsize=(8, 4))
plt.subplot(121)
plt.plot([0, L_true.max() * 1.05], [0, L_true.max() * 1.05], '--b')
for i in range(K):
for c in range(C):
plt.plot(L_true[c, i], L[c, i], color='C{}'.format(c),
marker=['o', '*', 'd'][i]) # this line will fail for K > 3
ax = plt.gca()
ax.axis('equal')
plt.xlabel('True firing rate (Hz)')
plt.ylabel('Inferred firing rate (Hz)')
xlim, ylim = ax.get_xlim(), ax.get_ylim()
for c in range(C):
plt.plot([-10^6], [-10^6], 'o', color='C{}'.format(c))
for i in range(K):
plt.plot([-10^6], [-10^6], marker=['o', '*', 'd'][i], c="black")
l = plt.legend(ax.lines[-C - K:], [f'cell {c + 1}' for c in range(C)] + [f'state {i + 1}' for i in range(K)])
ax.set_xlim(xlim), ax.set_ylim(ylim)
plt.subplot(122)
ymax = np.max(A_true - np.diag(np.diag(A_true))) / dt * 1.05
plt.plot([0, ymax], [0, ymax], '--b')
for j in range(K):
for i in range(K):
if i == j:
continue
plt.plot(A_true[i, j] / dt, A[i, j] / dt, 'o')
ax = plt.gca()
ax.axis('equal')
plt.xlabel('True transition rate (Hz)')
plt.ylabel('Inferred transition rate (Hz)')
l = plt.legend(ax.lines[1:], ['state 1 -> 2',
'state 1 -> 3',
'state 2 -> 1',
'state 2 -> 3',
'state 3 -> 1',
'state 3 -> 2'
])
plt.show(fig)
# @title Helper functions
def run_em(epochs, Y, psi, A, L, dt):
"""Run EM for the HMM spiking model.
Args:
epochs (int): Number of epochs of EM to run
Y (numpy 3d array): Tensor of recordings, has shape (n_trials, T, C)
psi (numpy vector): Initial probabilities for each state
A (numpy matrix): Transition matrix, A[i,j] represents the prob to switch
from j to i. Has shape (K,K)
L (numpy matrix): Poisson rate parameter for different cells.
Has shape (C,K)
dt (float): Duration of a time bin
Returns:
save_vals (lists of floats): Data for later plotting
lls (list of flots): ll Before each EM step
psi (numpy vector): Estimated initial probabilities for each state
A (numpy matrix): Estimated transition matrix, A[i,j] represents
the prob to switch from j to i. Has shape (K,K)
L (numpy matrix): Estimated Poisson rate parameter for different
cells. Has shape (C,K)
"""
save_vals = []
lls = []
for e in range(epochs):
# Run E-step
ll, gamma, xi = e_step(Y, psi, A, L, dt)
lls.append(ll) # log the data log likelihood for current cycle
if e % print_every == 0: print(f'epoch: {e:3d}, ll = {ll}') # log progress
# Run M-step
psi_new, A_new, L_new = m_step(gamma, xi, dt)
"""Booking keeping for later plotting
Calculate the difference of parameters for later
interpolation/extrapolation
"""
dp, dA, dL = psi_new - psi, A_new - A, L_new - L
# Calculate LLs and ECLLs for later plotting
if e in plot_epochs:
b_min = -min([np.min(psi[dp > 0] / dp[dp > 0]),
np.min(A[dA > 0] / dA[dA > 0]),
np.min(L[dL > 0] / dL[dL > 0])])
b_max = -max([np.max(psi[dp < 0] / dp[dp < 0]),
np.max(A[dA < 0] / dA[dA < 0]),
np.max(L[dL < 0] / dL[dL < 0])])
b_min = np.max([.99 * b_min, b_lims[0]])
b_max = np.min([.99 * b_max, b_lims[1]])
bs = np.linspace(b_min, b_max, num_plot_vals)
bs = sorted(list(set(np.hstack((bs, [0, 1])))))
bs = np.array(bs)
lls_for_plot = []
eclls_for_plot = []
for i, b in enumerate(bs):
ll = e_step(Y, psi + b * dp, A + b * dA, L + b * dL, dt)[0]
lls_for_plot.append(ll)
rate = (L + b * dL) * dt
ecll = ((gamma[:, 0] @ np.log(psi + b * dp) +
(xi * np.log(A + b * dA)).sum(axis=(-1, -2, -3)) +
(gamma * stats.poisson(rate).logpmf(Y[..., np.newaxis]).sum(-2)
).sum(axis=(-1, -2))).mean() / T / dt)
eclls_for_plot.append(ecll)
if b == 0:
diff_ll = ll - ecll
lls_for_plot = np.array(lls_for_plot)
eclls_for_plot = np.array(eclls_for_plot) + diff_ll
save_vals.append((bs, lls_for_plot, eclls_for_plot))
# return new parameter
psi, A, L = psi_new, A_new, L_new
ll = e_step(Y, psi, A, L, dt)[0]
lls.append(ll)
print(f'epoch: {epochs:3d}, ll = {ll}')
return save_vals, lls, psi, A, L
セクション0: はじめに
#@title Video 1: Introduction
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="ceQXN0OUaFo", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_Introduction_Video")
セクション1: ポアソンスパイキングニューロンネットワークのHMM
#@title Video 2: HMM for Poisson spiking neurons case study
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="Wb8mf5chmyI", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_HMM_for_Poisson_spiking_neurons_Video")
ノイズの多い神経または行動の測定値が与えられた場合、神経科学者としては、時間とともに変化する観測されていない潜在変数を推定したいことがよくあります。視床中継ニューロンは2つの異なるモードで発火します:スパイクが1つずつ発生するトニックモードと、複数の活動電位が連続して発生する「バーストモード」です。これらのモードは、感覚受容器から皮質への情報伝達を異なる方法で符号化していると考えられています。異なる分子機構であるT型カルシウムチャネルがニューロンのモードを切り替えますが、生きたサルの脳内でこれを測定するのは非常に困難です。しかし、統計的アプローチにより、行動中のサルで測定可能なスパイク活動のみからこれらのカルシウムチャネルの隠れた状態を復元することが可能です。
ここでは、その問題の簡略化バージョンに取り組みます。
イントロ講義で述べた定式化を考えましょう。
個のニューロンのネットワークが状態間を切り替えます。ニューロンは状態で発火率を持ちます。状態間の遷移はの遷移行列と、時刻での長さの初期確率ベクトルで表されます。
番目の時間ビンにおける細胞のスパイク数をとします。
以下の演習(1と2)およびチュートリアルでは、
- 、のモデルのインスタンスを定義する
- このモデルからデータセットを生成する
- (演習1)このHMMのMステップを実装する
- EMを実行してすべてのパラメータを推定する
- 学習の尤度曲線をプロットする
- 期待完全対数尤度とデータ対数尤度をプロットする
- 学習したパラメータと真のパラメータを比較する
モデルの定義とデータ生成
まず、隠れマルコフ連鎖からランダムな状態列を生成し、n_frozen_trials個の異なる試行で各細胞のスパイク列を生成します。ここでは、すべての細胞が先ほど生成した同じ基底状態列を使うと仮定します。
提案
-
以下の「モデルとシミュレーションパラメータ」および「真のモデルの初期化」の2つのセクションを実行して、真のモデルとパラメータを定義してください。パラメータを確認し、将来不明な変数があればこれらのセルに戻って確認してください。
-
与えられた状態列をすべての細胞のすべての時刻に対応するスパイク率に変換するコードを実行し、提供されたコードでスパイク列を可視化してください。
モデルとシミュレーションパラメータ
# model and data parameters
C = 5 # number of cells
K = 3 # number of states
dt = 0.002 # seconds
trial_T = 2.0 # seconds
n_frozen_trials = 20 # used to plot multiple trials with the same state sequence
n_trials = 300 # number of trials (each has it's own state sequence)
# for random data
max_firing_rate = 50 # Hz
max_transition_rate = 3 # Hz
# needed to plot LL and ECLL for every M-step
# **This substantially slows things down!!**
num_plot_vals = 10 # resolution of the plot (this is the expensive part)
b_lims = (-1, 2) # lower limit on graph (b = 0 is start-of-M-step LL; b = 1 is end-of-M-step LL)
plot_epochs = list(range(9)) # list of epochs to plot
真のモデルの初期化
np.random.seed(101)
T = round(trial_T / dt)
ts = np.arange(T)
# initial state distribution
psi = np.arange(1, K + 1)
psi = psi / psi.sum()
# off-diagonal transition rates sampled uniformly
A = np.random.rand(K, K) * max_transition_rate * dt
A = (1. - np.eye(K)) * A
A = A + np.diag(1 - A.sum(1))
# hand-crafted firing rates make good plots
L = np.array([
[.02, .8, .37],
[1., .7, .1],
[.92, .07, .5],
[.25, .42, .75],
[.15, .2, .85]
]) * max_firing_rate # (C,K)
# Save true parameters for comparison later
psi_true = psi
A_true = A
L_true = L
凍結された状態列でデータ生成とプロット
状態列[0,1,1,3,2,...]が与えられた場合、まず各状態を「ワンホット」符号化に変換します。例えば、状態数が5の場合、状態0のワンホット符号は[1,0,0,0,0]で、状態3の符号は[0,0,0,1,0]です。長さTの列があるとすると、この列のワンホット符号化Xfは形状(T,K)を持ちます。
np.random.seed(101)
# sample n_frozen_trials state sequences
Xf = np.zeros(T, dtype=int)
Xf[0] = (psi.cumsum() > np.random.rand()).argmax()
for t in range(1, T):
Xf[t] = (A[Xf[t - 1],:].cumsum() > np.random.rand()).argmax()
# switch to one-hot encoding of the state
Xf = np.eye(K, dtype=int)[Xf] # (T,K)
# get the Y values
Rates = np.squeeze(L @ Xf[..., None]) * dt # (T,C)
Rates = np.tile(Rates, [n_frozen_trials, 1, 1]) # (n_trials, T, C)
Yf = stats.poisson(Rates).rvs()
plot_spike_train(Xf, Yf, dt)
EM学習のためのデータ生成
前回のデータセットは可視化のために同じ状態系列で生成しました。今回は、各試行ごとにランダムに生成された系列を持つ観測値をn_trials回分生成しましょう。
np.random.seed(101)
# sample n_trials state sequences
X = np.zeros((n_trials, T), dtype=int)
X[:, 0] = (psi_true.cumsum(0)[:, None] > np.random.rand(n_trials)).argmax(0)
for t in range(1, T):
X[:, t] = (A_true[X[:, t - 1], :].T.cumsum(0) > np.random.rand(n_trials)).argmax(0)
# switch to one-hot encoding of the state
one_hot = np.eye(K)[np.array(X).reshape(-1)]
X = one_hot.reshape(list(X.shape) + [K])
# get the Y values
Y = stats.poisson(np.squeeze(L_true @ X[..., None]) * dt).rvs() # (n_trials, T, C)
print("Y has shape: (n_trial={},T={},C={})".format(*Y.shape))
セクション2: HMMのためのEMアルゴリズム
#@title Video 3: EM Tutorial
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="umU4wUWlKvg", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_EM_tutorial_Video")
データの尤度を最大化するパラメータの最適値を見つけることは、すべての潜在変数を積分する必要があるため、実質的に不可能です。必要な時間はに対して指数関数的に増加します。そこで代替手法として、期待値最大化(EM)アルゴリズムを用います。これはEステップとMステップを交互に繰り返し、各EMサイクル後にデータの尤度が減少しない(通常は増加する)ことが保証されています。
このセクションでは、HMMのEMアルゴリズムを簡単に復習し、以下を示します。
- 順方向および逆方向確率との再帰式
- データを観測した後の単一およびペアの周辺分布の表現:および
- データ尤度を増加させるの更新値の閉形式解
Eステップ: 順方向-逆方向アルゴリズム
順方向パスでは、と現在および過去のデータの結合確率、すなわち順方向確率を以下のように再帰的に計算します。
導入部とは異なり、ここでのは状態から状態への遷移確率を意味します。
逆方向パスでは、逆方向確率を計算します。これは現在の状態が与えられたときに将来のすべてのデータ点を観測する尤度です。の再帰式は以下の通りです。
過去と未来の情報を組み合わせると、単一およびペアの周辺分布は以下のように与えられます。
ここでです。
Mステップ
HMMのMステップは閉形式解があります。まず、新しい遷移行列は以下のように与えられます。
これは期待される経験的遷移確率です。
新しい初期確率と放出モデルのパラメータも、単一およびペアの周辺分布に基づく経験的値として与えられます:
\begin{align}
&= \
&= \frac{\sum_{t} \gamma_{i}(t) y_{t}^{c}}{\sum_{t} \gamma_{i}(t) d t}
\end{align}
Eステップ: 順方向および逆方向アルゴリズム
(任意)
このセクションでは、順方向-逆方向アルゴリズムのコードを読み、すべての試行を一度に計算することでnumpyで効率的に計算を実装する方法を理解します。
順方向および逆方向の再帰式をより簡潔な形で書き直しましょう:
\begin{eqnarray}
&=& \
&=& \sum_j A_{ij} o_j^{t+1}b_j^{t+1} \text{, ここで } o_j^{t}=p(y_{t}|x_{t}=j)
\end{eqnarray}
逆方向再帰を例にとりましょう。実際には、試行は互いに独立なので全試行をまとめて扱います。試行インデックスを再帰式に加えると、逆方向再帰は次のようになります。
手元にあるのは:
A: サイズ(K,K)の行列- : サイズ
(N,K)の配列で、全試行のある時刻における対数データ尤度 - : サイズ
(N,K)の配列で、全試行のある時刻における逆方向確率
ここでNは試行数を表します。
これら3つの配列のインデックスのサイズと意味は一致しません。の1次元目は、とはなので、単純に掛け合わせることはできません。しかし、とを1行の行列として見なすことで、逆方向の式を次のように書き換えられます。
これで3つの配列を要素ごとに掛けて最後の次元で合計できます。
numpyでは、配列に新しい次元を挿入したい位置にNoneを使ってインデックス指定することでこれを実現できます。例えば、サイズ(N,T,K)のbについて、b[:,t,:]は形状(N,K)、b[:,t,None,:]は(N,1,K)、b[:,t,:,None]は(N,K,1)となります。
したがって、逆方向再帰の計算は以下のように実装できます。
b[:,t,:] = (A * o[:,t+1,None,:] * b[:,t+1,None,:]).sum(-1)
上記のトリックに加え、この演習では数値安定性のため対数スケールで作業します。
提案: 順方向および逆方向の再帰のコードを確認してみてください。
def e_step(Y, psi, A, L, dt):
"""Calculate the E-step for the HMM spiking model.
Args:
Y (numpy 3d array): tensor of recordings, has shape (n_trials, T, C)
psi (numpy vector): initial probabilities for each state
A (numpy matrix): transition matrix, A[i,j] represents the prob to
switch from i to j. Has shape (K,K)
L (numpy matrix): Poisson rate parameter for different cells.
Has shape (C,K)
dt (float): Bin length
Returns:
ll (float): data log likelihood
gamma (numpy 3d array): singleton marginal distribution.
Has shape (n_trials, T, K)
xi (numpy 4d array): pairwise marginal distribution for adjacent
nodes . Has shape (n_trials, T-1, K, K)
"""
n_trials = Y.shape[0]
T = Y.shape[1]
K = psi.size
log_a = np.zeros((n_trials, T, K))
log_b = np.zeros((n_trials, T, K))
log_A = np.log(A)
log_obs = stats.poisson(L * dt).logpmf(Y[..., None]).sum(-2) # n_trials, T, K
# forward pass
log_a[:, 0] = log_obs[:, 0] + np.log(psi)
for t in range(1, T):
tmp = log_A + log_a[:, t - 1, : ,None] # (n_trials, K,K)
maxtmp = tmp.max(-2) # (n_trials,K)
log_a[:, t] = (log_obs[:, t] + maxtmp +
np.log(np.exp(tmp - maxtmp[:, None]).sum(-2)))
# backward pass
for t in range(T - 2, -1, -1):
tmp = log_A + log_b[:, t + 1, None] + log_obs[:, t + 1, None]
maxtmp = tmp.max(-1)
log_b[:, t] = maxtmp + np.log(np.exp(tmp - maxtmp[..., None]).sum(-1))
# data log likelihood
maxtmp = log_a[:, -1].max(-1)
ll = np.log(np.exp(log_a[:, -1] - maxtmp[:, None]).sum(-1)) + maxtmp
# singleton and pairwise marginal distributions
gamma = np.exp(log_a + log_b - ll[:, None, None])
xi = np.exp(log_a[:, :-1, :, None] + (log_obs + log_b)[:, 1:, None] +
log_A - ll[:, None, None, None])
return ll.mean() / T / dt, gamma, xi
#@title Video 4: Implement the M-step
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="H4GGTg_9BaE", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_the_M_step_Video")
コーディング演習1: Mステップの実装
この演習では、前述の閉形式解を用いてこのHMMのMステップを完成させます。
提案
- 単一周辺分布の経験的カウントとして新しい初期確率を計算する
- 試行の次元が追加されていることを忘れず、すべての試行で平均をとる
参考:
新しい遷移行列は周辺分布からの遷移イベントの経験的カウントとして計算されます。
各セルおよび各状態の新しい発火率は以下の通りです。
def m_step(gamma, xi, dt):
"""Calculate the M-step updates for the HMM spiking model.
Args:
gamma (numpy 3d array): singleton marginal distribution.
Has shape (n_trials, T, K)
xi (numpy 3d array): Tensor of recordings, has shape (n_trials, T, C)
dt (float): Duration of a time bin
Returns:
psi_new (numpy vector): Updated initial probabilities for each state
A_new (numpy matrix): Updated transition matrix, A[i,j] represents the
prob. to switch from j to i. Has shape (K,K)
L_new (numpy matrix): Updated Poisson rate parameter for different
cells. Has shape (C,K)
"""
raise NotImplementedError("`m_step` need to be implemented")
############################################################################
# Insert your code here to:
# Calculate the new prior probabilities in each state at time 0
# Hint: Take the first time step and average over all trials
###########################################################################
psi_new = ...
# Make sure the probabilities are normalized
psi_new /= psi_new.sum()
# Calculate new transition matrix
A_new = xi.sum(axis=(0, 1)) / gamma[:, :-1].sum(axis=(0, 1))[:, np.newaxis]
# Calculate new firing rates
L_new = (np.swapaxes(Y, -1, -2) @ gamma).sum(axis=0) / gamma.sum(axis=(0, 1)) / dt
return psi_new, A_new, L_new
# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_M_step_Exercise")
#@title Video 5: Running and plotting EM
# Insert the ID of the corresponding youtube video
from IPython.display import YouTubeVideo
video = YouTubeVideo(id="6UTsXxE3hG0", width=854, height=480, fs=1)
print("Video available at https://youtu.be/" + video.id)
video
# @title Submit your feedback
content_review(f"{feedback_prefix}_Running_and_plotting_EM_Video")
EMの実行
パラメータの初期化
np.random.seed(101)
# number of EM steps
epochs = 9
print_every = 1
# initial state distribution
psi = np.arange(1, K + 1)
psi = psi / psi.sum()
# off-diagonal transition rates sampled uniformly
A = np.ones((K, K)) * max_transition_rate * dt / 2
A = (1 - np.eye(K)) * A
A = A + np.diag(1 - A.sum(1))
# firing rates sampled uniformly
L = np.random.rand(C, K) * max_firing_rate
# LL for true vs. initial parameters
print(f'LL for true 𝜃: {e_step(Y, psi_true, A_true, L_true, dt)[0]}')
print(f'LL for initial 𝜃: {e_step(Y, psi, A, L, dt)[0]}\n')
# Run EM
save_vals, lls, psi, A, L = run_em(epochs, Y, psi, A, L, dt)
# EM doesn't guarantee the order of learnt latent states are the same as that of true model
# so we need to sort learnt parameters
# Compare all true and estimated latents across cells
cost_mat = np.sum((L_true[..., np.newaxis] - L[:, np.newaxis])**2, axis=0)
true_ind, est_ind = linear_sum_assignment(cost_mat)
psi = psi[est_ind]
A = A[est_ind]
A = A[:, est_ind]
L = L[:, est_ind]
トレーニング過程と学習済みモデルのプロット
EMの進行状況のプロット
これで以下ができます。
- トレーニング中の尤度のプロット
- Mステップの対数尤度と期待完全対数尤度(ECLL)のプロットにより、EMの動作とECLLの凸性の直感を得る
- 学習済みパラメータと真のパラメータの比較プロット
# Plot the log likelihood after each epoch of EM
with plt.xkcd():
plot_lls(lls)
# For each saved epoch, plot the log likelihood and expected complete log likelihood
# for the initial and final parameter values
with plt.xkcd():
plot_lls_eclls(plot_epochs, save_vals)
学習済みパラメータと真のパラメータの比較プロット
ここでは、(ソートした)学習済みパラメータと真のパラメータをプロットし、すべてのパラメータを正しく復元できたかを確認します。
# Compare true and learnt parameters
with plt.xkcd():
plot_learnt_vs_true(L_true, L, A_true, A, dt)