Open In Colab   Open in Kaggle

チュートリアル 1: 離散状態の最適制御

ボーナス: 最適制御

Neuromatch Academyによる

コンテンツ作成者: Zhengwei Wu, Itzel Olivos Castillo, Shreya Saxena, Xaq Pitkow

コンテンツレビュアー: Karolina Stosio, Roozbeh Farhoodi, Saeed Salehi, Ella Batty, Spiros Chavlis, Matt Krause, Michael Waskom, Melisa Maidana Capitan

制作編集者: Gagana B, Spiros Chavlis


チュートリアルの目的

推定所要時間: 60分

このチュートリアルでは、二値制御タスクを実装します。これは釣りを表す部分観測マルコフ決定過程(POMDP)です。エージェント(あなた)は、魚の群れがどこにいるかを直接観察できない状態で、2つの釣り場から報酬を得ようとします(はい、魚の群れは「スクール」と呼ばれます!)。これにより、世界はHidden Dynamicsの日と同様に隠れマルコフモデル(HMM)となります。いつどこで魚を捕まえたかに基づいて、過去の観測に基づく魚の位置の事後分布、すなわち信念を更新し続けます。あなたは、サイドを切り替えるコストを最小限に抑えつつ、最も多くの魚を獲得するように位置を制御すべきです。

これまでに確率的ダイナミクス、潜在状態、観測について学びました。最初の演習はこれらの復習が中心です。ここで新たに行動を導入し、制御、効用、方策の新しい概念に基づきます。この一般的な構造は、動物が情報を収集し、環境について推論し、最大の利益をもたらす行動を選択できる知覚-行動ループを含むため、脳の計算の基礎モデルを提供します。ニューロンがこれらの計算をどのように実装するかは別の問題であり、このレッスンでは扱いません。

このチュートリアルであなたは:

# @title Tutorial slides
# @markdown These are the slides for all videos in this tutorial.
from IPython.display import IFrame
link_id = "8j5rs"
print(f"If you want to download the slides: https://osf.io/download/{link_id}/")
IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render", width=854, height=480)

セットアップ

# @title Install and import feedback gadget


from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_cn",
            "user_key": "y1x3mpx5",
        },
    ).render()


feedback_prefix = "W3D3_T1"
# Imports
import numpy as np
from math import isclose
import matplotlib.pyplot as plt
# @title Figure Settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True

import ipywidgets as widgets
from IPython.display import HTML
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")
# @title Plotting Functions

def plot_fish(fish_state, ax=None, show=True):
  """
  Plot the fish dynamics (states across time)
  """
  T = len(fish_state)

  offset = 3

  if not ax:
    fig, ax = plt.subplots(1, 1, figsize=(12, 3.5))

  x = np.arange(0, T, 1)
  y = offset * (fish_state*2 - 1)

  ax.plot(y, color='cornflowerblue', markersize=10, linewidth=3.0, zorder=0)
  ax.fill_between(x, y, color='cornflowerblue', alpha=.3)

  ax.set_xlabel('time')
  ax.set_ylabel('fish location')

  ax.set_xlim([0, T])
  ax.set_xticks([])
  ax.xaxis.set_label_coords(1.05, .54)

  ax.set_ylim([-(offset+.5), offset+.5])
  ax.set_yticks([-offset, offset])
  ax.set_yticklabels(['left', 'right'])

  ax.spines['bottom'].set_position('center')
  if show:
    plt.show()


def plot_measurement(measurement, ax=None, show=True):
  """
  Plot the measurements
  """
  T = len(measurement)

  rel_pos = 3
  red_y = []
  blue_y = []
  for idx, value in enumerate(measurement):
    if value == 0:
      blue_y.append([idx, -rel_pos])
    else:
      red_y.append([idx, rel_pos])

  red_y = np.asarray(red_y)
  blue_y = np.asarray(blue_y)

  if not ax:
    fig, ax = plt.subplots(1, 1, figsize=(12, 3.5))

  if len(red_y) > 0:
    ax.plot(red_y[:, 0], red_y[:, 1], '*', markersize=8, color='crimson')

  if len(blue_y) > 0:
    ax.plot(blue_y[:, 0], blue_y[:, 1], '*', markersize=8, color='royalblue')

  ax.set_xlabel('time', fontsize=18)
  ax.set_ylabel('Caught fish?')

  ax.set_xlim([0, T])
  ax.set_xticks([])
  ax.xaxis.set_label_coords(1.05, .54)

  ax.set_ylim([-rel_pos - .5, rel_pos + .5])
  ax.set_yticks([-rel_pos, rel_pos])
  ax.set_yticklabels(['no', 'yes!'])

  ax.spines['bottom'].set_position('center')
  if show:
    plt.show()


def plot_act_loc(loc, act, ax_loc=None, show=True):
  """
  Plot the action and location of T time points
  """
  T = len(act)

  if not ax_loc:
    fig, ax_loc = plt.subplots(1, 1, figsize=(12, 2.5))

  loc = loc*2 - 1
  act_down = []
  act_up = []
  for t in range(1, T):
    if loc[t-1] == -1 and loc[t] == 1:
      act_up.append([t - 0.5, 0])
    if loc[t-1] == 1 and loc[t] == -1:
      act_down.append([t - 0.5, 0])

  act_down = np.array(act_down)
  act_up = np.array(act_up)

  ax_loc.plot(loc, 'g.-', markersize=8, linewidth=5)

  if len(act_down) > 0:
    ax_loc.plot(act_down[:, 0], act_down[:, 1], 'rv', markersize=18, zorder=10, label='switch')

  if len(act_up) > 0:
    ax_loc.plot(act_up[:, 0], act_up[:, 1], 'r^', markersize=18, zorder=10)

  ax_loc.set_xlabel('time')
  ax_loc.set_ylabel('Your state')

  ax_loc.set_xlim([0, T])
  ax_loc.set_xticks([])
  ax_loc.xaxis.set_label_coords(1.05, .54)

  if len(act_down) > 0:
    ax_loc.legend(loc="upper right")
  elif len(act_down) == 0 and len(act_up) > 0:
    ax_loc.plot(act_up[:, 0], act_up[:, 1], 'r^', markersize=18, zorder=10, label='switch')
    ax_loc.legend(loc="upper right")

  ax_loc.set_ylim([-1.1, 1.1])
  ax_loc.set_yticks([-1, 1])

  ax_loc.tick_params(axis='both', which='major')
  ax_loc.set_yticklabels(['left', 'right'])

  ax_loc.spines['bottom'].set_position('center')

  if show:
    plt.show()


def plot_belief(belief, ax1=None, choose_policy=None, show=True):
  """
  Plot the belief dynamics of T time points
  """

  T = belief.shape[1]

  if not ax1:
      fig, ax1 = plt.subplots(1, 1, figsize=(12, 2.5))

  ax1.plot(belief[1, :], color='midnightblue', markersize=10, linewidth=3.0)

  ax1.set_xlabel('time')
  ax1.set_ylabel('Belief (right)')

  ax1.set_xlim([0, T])
  ax1.set_xticks([])
  ax1.xaxis.set_label_coords(1.05, 0.05)

  ax1.set_yticks([0, 1])
  ax1.set_ylim([0, 1.1])

  labels = [item.get_text() for item in ax1.get_yticklabels()]
  ax1.set_yticklabels(['    0', '    1'])

  """
  if choose_policy == "threshold":
    ax2 = ax1.twinx()
    ax2.plot(time_range, threshold * np.ones(time_range.shape), 'r--')
    ax2.plot(time_range, (1 - threshold) * np.ones(time_range.shape), 'c--')
    ax2.set_yticks([threshold, 1 - threshold])
    ax2.set_ylim([0, 1.1])
    ax2.tick_params(axis='both', which='major', labelsize=18)
    labels = [item.get_text() for item in ax2.get_yticklabels()]
    labels[0] = 'threshold to switch \n from left to right'
    labels[-1] = 'threshold to switch \n from right to left'
    ax2.set_yticklabels(labels)
  """
  if show:
    plt.show()


def plot_dynamics(belief, loc, act, meas, fish_state, choose_policy):
  """
  Plot the dynamics of T time points
  """
  if choose_policy == 'threshold':
    fig, [ax0, ax_bel, ax_loc, ax1] = plt.subplots(4, 1, figsize=(12, 9))
    plot_fish(fish_state, ax=ax0, show=False)
    plot_belief(belief, ax1=ax_bel, show=False)
    plot_measurement(meas, ax=ax1, show=False)
    plot_act_loc(loc, act, ax_loc=ax_loc)
  else:
    fig, [ax0, ax_bel, ax1] = plt.subplots(3, 1, figsize=(12, 7))
    plot_fish(fish_state, ax=ax0, show=False)
    plot_belief(belief, ax1=ax_bel, show=False)
    plot_measurement(meas, ax=ax1, show=False)

  plt.tight_layout()
  plt.show()


def belief_histogram(belief, bins=100):
  """
  Plot the histogram of belief states
  """
  fig, ax = plt.subplots(1, 1, figsize=(8, 6))
  ax.hist(belief, bins)
  ax.set_xlabel('belief', fontsize=18)
  ax.set_ylabel('count', fontsize=18)
  plt.show()


def plot_value_threshold(threshold_array, value_array):
  """
  Helper function to plot the value function and threshold
  """
  yrange = np.max(value_array) - np.min(value_array)
  star_loc = np.argmax(value_array)

  fig_, ax = plt.subplots(1, 1, figsize=(8, 6))
  ax.plot(threshold_array, value_array, 'b')
  ax.vlines(threshold_array[star_loc],
            min(value_array) - yrange * .1, max(value_array),
            colors='red', ls='--')
  ax.plot(threshold_array[star_loc],
          value_array[star_loc],
          '*', color='crimson',
          markersize=20)

  ax.set_ylim([np.min(value_array) - yrange * .1,
               np.max(value_array) + yrange * .1])
  ax.set_title(f'threshold vs value with switching cost c = {cost_sw:.2f}',
               fontsize=20)
  ax.set_xlabel('threshold', fontsize=16)
  ax.set_ylabel('value', fontsize=16)
  plt.show()
# @title Helper Functions

# To generate a binomial with fixed "noise",
# we generate a fist sequence of T numbers uniformly at random
# this sequence can be changed later with different size T
np.random.seed(42)
init_T = 100
rnd_tele = np.random.uniform(0, 1, init_T)
rnd_high_rwd = np.random.uniform(0, 1, init_T)
rnd_low_rwd = np.random.uniform(0, 1, init_T)


def get_randomness(T):

  global rnd_tele
  global rnd_high_rwd
  global rnd_low_rwd

  rnd_tele = np.random.uniform(0, 1, T)
  rnd_high_rwd = np.random.uniform(0, 1, T)
  rnd_low_rwd = np.random.uniform(0, 1, T)


def binomial_tele(p, T):
  if len(rnd_tele) < T:
    get_randomness(T)

  return np.array([1 if p > rnd_tele[i] else 0 for i in range(T)])


# Need a better name
def getRandomness(p, T):

  global rnd_tele
  global rnd_high_rwd
  global rnd_low_rwd

  rnd_tele = np.random.uniform(0, 1, T)
  rnd_high_rwd = np.random.uniform(0, 1, T)
  rnd_low_rwd = np.random.uniform(0, 1, T)

  return [binomial_tele(p, T), rnd_high_rwd, rnd_low_rwd]


class ExcerciseError(AssertionError):
  pass


class binaryHMM():

  def __init__(self, params, T, fish_initial=0, loc_initial=0):
    self.params = params
    self.fish_initial = fish_initial
    self.loc_initial = loc_initial
    self.T = T

  def fish_dynamics(self):
    """
    fish state dynamics according to telegraph process

    Returns:
      fish_state (numpy array of int)
    """

    p_stay, _, _, _ = self.params
    fish_state = np.zeros(self.T, int)  # 0: left side and 1: right side

    # initialization
    fish_state[0] = self.fish_initial
    tele_operations = binomial_tele(p_stay, self.T) # 0: switch and 1: stay

    for t in range(1, self.T):
      # we use logical operation NOT XOR to determine the next state
      fish_state[t] = int(not(fish_state[t-1] ^ tele_operations[t]))

    return fish_state

  def generate_process_lazy(self):
    """
    fish dynamics and rewards if you always stay in the initial location
    without changing sides

    Returns:
      fish_state (numpy array of int): locations of the fish
      loc (numpy array of int): left or right site, 0 for left, and 1 for right
      rwd (numpy array of binary): whether a fish was caught or not
    """

    _, p_low_rwd, p_high_rwd, _ = self.params

    fish_state = self.fish_dynamics()
    rwd = np.zeros(self.T, int)  # 0: no food, 1: get food

    for t in range(0, self.T):
      # new measurement
      if fish_state[t] != self.loc_initial:
        rwd[t] = 1 if p_low_rwd > rnd_low_rwd[t] else 0
      else:
        rwd[t] = 1 if p_high_rwd > rnd_high_rwd[t] else 0

      # rwd[t] = binomial(1, p_rwd_vector[(fish_state[t] == loc[t]) * 1])
    return fish_state, self.loc_initial*np.ones(self.T), rwd


class binaryHMM_belief(binaryHMM):

  def __init__(self, params, T,
               fish_initial=0, loc_initial=1,
               choose_policy='threshold'):

      binaryHMM.__init__(self, params, T, fish_initial, loc_initial)
      self.choose_policy = choose_policy

  def generate_process(self):
    """
    fish dynamics and measurements based on the chosen policy

    Returns:
      belief (numpy array of float): belief on the states of the two sites
      act (numpy array of string): actions over time
      loc (numpy array of int): left or right site
      measurement (numpy array of binary): whether a reward is obtained
      fish_state (numpy array of int): fish locations
    """

    p_stay, low_rew_p, high_rew_p, threshold = self.params
    fish_state = self.fish_dynamics()     # 0: left side; 1: right side
    loc = np.zeros(self.T, int)                # 0: left side, 1: right side
    measurement = np.zeros(self.T, int)        # 0: no food, 1: get food
    act = np.empty(self.T, dtype='object')     # "stay", or "switch"
    belief = np.zeros((2, self.T), float)      # the probability that the fish is on the left (1st element)
                                          # or on the right (2nd element),
                                          # the beliefs on the two boxes sum up to be 1

    rew_prob = np.array([low_rew_p, high_rew_p])

    # initialization
    loc[0] = self.loc_initial
    measurement[0] = 0
    belief_0 = np.random.random(1)[0]
    belief[:, 0] = np.array([belief_0, 1 - belief_0])
    act[0] = self.policy(threshold, belief[:, 0], loc[0])

    for t in range(1, self.T):
      if act[t - 1] == "stay":
          loc[t] = loc[t - 1]
      else:
          loc[t] = int(not(loc[t - 1] ^ 0))

      # new measurement
      # measurement[t] = binomial(1, rew_prob[(fish_state[t] == loc[t]) * 1])
      if fish_state[t] != loc[t]:
          measurement[t] = 1 if low_rew_p > rnd_low_rwd[t] else 0
      else:
          measurement[t] = 1 if high_rew_p > rnd_high_rwd[t] else 0

      belief[0, t] = self.belief_update(belief[0, t - 1] , loc[t],
                                        measurement[t], p_stay,
                                        high_rew_p, low_rew_p)
      belief[1, t] = 1 - belief[0, t]

      act[t] = self.policy(threshold, belief[:, t], loc[t])

    return belief, loc, act, measurement, fish_state

  def policy(self, threshold, belief, loc):
    """
    chooses policy based on whether it is lazy policy
        or a threshold-based policy

    Args:
      threshold (float): the threshold of belief on the current site,
          when the belief is lower than the threshold, switch side
      belief (numpy array of float): the belief on the two sites
      loc (int) : the location of the agent

    Returns:
      act (string): "stay" or "switch"
    """
    if self.choose_policy == "threshold":
      act = policy_threshold(threshold, belief, loc)
    if self.choose_policy == "lazy":
      act = policy_lazy(belief, loc)

    return act

  def belief_update(self, belief_past, loc, measurement, p_stay,
                    high_rew_p, low_rew_p):
    """
    using PAST belief on the LEFT box, CURRENT location and
        and measurement to update belief
    """
    rew_prob_matrix = np.array([[1 - high_rew_p, high_rew_p],
                                [1 - low_rew_p, low_rew_p]])

    # update belief posterior, p(s[t] | measurement(0-t), act(0-t-1))
    belief_0 = (belief_past * p_stay  + (1 - belief_past) * (1 - p_stay)) *\
                            rew_prob_matrix[(loc + 1) // 2, measurement]
    belief_1 = ((1 - belief_past) * p_stay + belief_past * (1 - p_stay)) *\
                            rew_prob_matrix[1-(loc + 1) // 2, measurement]

    belief_0 = belief_0 / (belief_0 + belief_1)

    return belief_0


def policy_lazy(belief, loc):
  """
  This function is a lazy policy where stay is also taken
  """
  act = "stay"

  return act


def test_policy_threshold():
  well_done = True
  for loc in [-1, 1]:
    threshold = 0.4
    belief = np.array([.2, .3])
    if policy_threshold(threshold, belief, loc) != "switch":
      raise ExcerciseError("'policy_threshold' function is not correctly implemented!")
  for loc in [1, -1]:
    threshold = 0.6
    belief = np.array([.7, .8])
    if policy_threshold(threshold, belief, loc) != "stay":
      raise ExcerciseError("'policy_threshold' function is not correctly implemented!")
  print("Well Done!")


def test_policy_threshold():
  for loc in [-1, 1]:
    threshold = 0.4
    belief = np.ones(2) * (threshold + 0.1)
    belief[(loc + 1) // 2] = threshold - 0.1

    if policy_threshold(threshold, belief, loc) != "switch":
      raise ExcerciseError("'policy_threshold' function is not correctly implemented!")
    if policy_threshold(threshold, belief, -1 * loc) != "stay":
      raise ExcerciseError("'policy_threshold' function is not correctly implemented!")

  print("Well Done!")


def test_value_function():
  measurement = np.array([0, 0, 0, 1, 0, 0, 0, 0, 1, 1])
  act = np.array(["switch", "stay", "switch", "stay", "stay",
                  "stay", "switch", "switch", "stay", "stay"])
  cost_sw = .5
  if not isclose(get_value(measurement, act, cost_sw), .1):
    raise ExcerciseError("'value_function' function is not correctly implemented!")
  print("Well Done!")

セクション 1: 問題の分析

# @title Video 1: Gone fishing
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', '3oIwUFpolVA'), ('Bilibili', 'BV1FL411p7o5')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Gone_fishing_Video")

問題設定

1. 状態のダイナミクス: 魚の位置は左か右の2つの可能性があります。秘密裏に、各時間ステップで魚は一定の確率 psw=1pstayp_{\rm sw} = 1 - p_{\rm stay} で左右を切り替えます。これはLinear Systemsの日に見た二値スイッチングモデル(テレグラフ過程)です。魚の位置 sfishs^{\rm fish} は潜在変数であり、魚を捕まえようとしたときに得られる観測により情報が得られます。これにより、過去の観測に基づく現在の位置の事後確率、すなわち信念が得られます。

2. 行動: 過去の日とは異なり、今度はプロセスに行動を起こせます!現在の位置(左または右)に留まるか、反対側に移動することができます。

3. 報酬とコスト: 捕まえた魚1匹につき1「ポイント」の報酬が得られます。魚と同じ側にいる場合、離散時間ステップごとに確率 qhighq_{\rm high} で魚を捕まえられます。そうでない場合でも確率 qlowq_{\rm low} で魚を捕まえられることがあります。

反対側に移動すると CC ポイントのコストがかかります。賢く決めましょう!


効用の最大化

賢く決めて総効用(総ポイント)を最大化するために、どんな状況でも何をすべきかを示す方策に従います。ここで状況はあなたの位置と魚の位置に関する信念 btb_t(事後分布)で決まります(魚の位置は潜在変数であることを思い出してください)。

最適制御理論では、信念は過去のすべての観測に基づく潜在変数の事後確率です。この事後に関して期待効用を最大化することが最適であることが示されています。

本問題では、魚は左か右のどちらかにいるため信念は1つの数値で表せます。すなわち:

bt=p(stfish=Right  m0:t,a0:t1)b_t = p(s^{\rm fish}_t = {\rm Right}\ |\ m_{0:t}, a_{0:t-1})

ここで m0:tm_{0:t} は観測、a0:t1a_{0:t-1} は行動(留まるか切り替えるか)です。

最後に、方策は信念に基づく単純な閾値でパラメータ化します:魚が現在いる側にいる確率が閾値 θ\theta を下回ったら反対側に切り替えます。

このチュートリアルで、適切な閾値を選べばこの単純な方策が最適であることを発見します!

インタラクティブデモ 1: 魚のダイナミクスの観察

このデモでは、魚が左右に動くダイナミクスを、あなたが一箇所に留まったまま観察します。魚が同じ場所に留まる確率 stay_prob を操作し、魚の動きの様子を観察してください。

考える質問:

# @markdown Execute this cell to enable the demo.
display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))

@widgets.interact(p_stay=widgets.FloatSlider(.9, description="stay_prob", min=0., max=1., step=0.01))

def update_ex_1(p_stay):
  """
    T: Length of timeline
    p_stay: probability that the fish do not swim to the other side at time t
  """
  params = [p_stay, _, _, _]

  # initial condition: fish [fish_initial] start at the left location (-1)
  binaryHMM_test = binaryHMM(params=params, fish_initial=1, T=100)

  fish_state = binaryHMM_test.fish_dynamics()
  plot_fish(fish_state)

plt.show()

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Examining_fish_dynamics_Interactive_Demo_and_Discussion")

セクション 2: 魚を捕まえる

# @title Video 2: Catch some fish
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'ZjB2_SAY2uE'), ('Bilibili', 'BV1kD4y1m7Lo')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Catch_some_fish_Video")

インタラクティブデモ 2: 報酬関数の観察

この2つ目のデモでは、ボタンであなたの位置を制御しますが、魚の位置は stay_prob = 1 に固定します。魚が一箇所で穏やかに泳いでいる状態で、魚と同じ側にいるときと反対側にいるときの報酬を視覚的に確認します。

魚と同じ側にいるときは魚を捕まえる確率が高くなるはずです(ただし、スライダーを操作して他の条件に変えることも技術的には可能です!)。

スライダー high_rew_prob(魚と同じ側にいるときの高い報酬確率)と low_rew_prob(反対側にいるときの低い報酬確率)を操作してください。ボタン(同じ場所 vs. 違う場所)が魚を捕まえる確率を決めます。

考える質問:

# @markdown Execute this cell to enable the demo.
display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))

@widgets.interact(locs=widgets.RadioButtons(options=['same location', 'different locations'],
                                            description='Fish and agent:',
                                            disabled=False,
                                            layout={'width': 'max-content'}),
                  p_low_rwd=widgets.FloatSlider(.1, description="low_rew_prob:",
                                                min=0., max=1.),
                  p_high_rwd=widgets.FloatSlider(.9, description="high_rew_prob:",
                                                 min=0., max=1.))

def update_ex_2(locs, p_low_rwd, p_high_rwd):
  """
    p_stay: probability of fish staying at current side at time t
    p_low_rwd: probability of catching fish when you're NOT on the side where the fish are swimming
    p_high_rwd: probability of catching fish when you're on the side where the fish are swimming
    fish_initial: initial side of fish (-1 left, 1 right)
    agent_initial: initial side of the agent (YOU!) (-1 left, 1 right)
  """
  p_stay = 1
  params = [p_stay, p_low_rwd, p_high_rwd, _]

  # initial condition for fish [fish_initial] and you [loc_initial]
  if locs == 'same location':
    binaryHMM_test = binaryHMM(params, fish_initial=0, loc_initial=0, T=100)
  else:
    binaryHMM_test = binaryHMM(params, fish_initial=1, loc_initial=0, T=100)

  fish_state, loc, measurement = binaryHMM_test.generate_process_lazy()
  plot_measurement(measurement)

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Examining_the_reward_function_Interactive_Demo_and_Discussion")

セクション 3: 信念のダイナミクスと信念分布

# @title Video 3: Where are the fish?
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'rmETVsRFYGk'), ('Bilibili', 'BV19t4y1Q7VH')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Where_are_the_fish_Video")

インタラクティブデモ 3: 信念の観察

ここで、信念がどのように計算されるかの直感を得ましょう。ここでの信念は、観測に基づく魚の位置の事後確率 p(stm0:t)p(s_t|m_{0:t}) です。これはHidden Dynamicsの日に行ったことと同じです!

この演習では、あなたは常に左側に留まりますが、魚は動き回ります。魚は確率 stay_prob で同じ側に留まります。魚の群れの位置は見えず、捕まえた魚だけが観測です。これらの観測を使って魚の群れの位置を推定します。

デモではスライダー high_rew_problow_rew_probstay_prob を操作してください。

考える質問:

# @markdown Execute this cell to enable the demo.
display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))

@widgets.interact(p_stay=widgets.FloatSlider(.96, description="stay_prob",
                                             min=.8, max=1., step=.01),
                  p_low_rwd=widgets.FloatSlider(.1, description="low_rew_prob",
                                                min=0., max=1., step=.01),
                  p_high_rwd=widgets.FloatSlider(.3, description="high_rew_prob",
                                                 min=0., max=1., step=.01))

def update_ex_2(p_stay, p_low_rwd, p_high_rwd):
  """
    T: Length of timeline
    p_stay: probability of fish staying at current side at time t
    p_high_rwd: probability of catching fish when you're on the side where the fish are swimming
    p_low_rwd: probability of catching fish when you're NOT on the side where the fish are swimming
    fish_initial: initial side of fish (0 left, 1 right)
    agent_initial: initial side of the agent (YOU!) (0 left, 1 right)
    threshold: threshold of belief below which the action is switching
  """
  threshold = 0.2
  params = [p_stay, p_low_rwd, p_high_rwd, threshold]

  binaryHMM_test = binaryHMM_belief(params, choose_policy="lazy",
                                    fish_initial=0, loc_initial=0, T=100)

  belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()
  plot_dynamics(belief, loc, act, measurement, fish_state,
                binaryHMM_test.choose_policy)

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Examining_the_beliefs_Interactive_Demo_and_Discussion")

セクション 4: 閾値方策の実装

# @title Video 4: How should you act?
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'cTzaQl2Vxn4'), ('Bilibili', 'BV1ri4y137cj')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_How_should_you_act_Video")

コーディング演習 4: 閾値ベース ポリシーに従う動的挙動

ここでは、上で使った「怠け者」ポリシーから、あなたが実装する閾値ポリシーに切り替えます。信念が十分に低くて魚がいる側にいると判断したら、位置を変えます。関数 policy_threshold(threshold, belief, loc) を更新してください。このポリシーは3つの入力を取ります:

  1. 魚の状態に関する belief(信念)。便宜上、時刻 t における信念は2次元ベクトルで表します。1つ目の要素は魚が左にいる信念、2つ目は魚が右にいる信念です。各時刻でこれらの要素は合計1になります。

  2. あなたの位置 loc は「左」= -1、「右」= 1 で表します。

  3. スイッチするかどうかを決める信念の threshold(閾値)。魚と同じ側にいる信念がこの閾値を下回ったら、反対側に移動し、それ以外はそのまま留まります。

関数は各時刻 t に対して、「stay」または「switch」のいずれかの行動を返すようにしてください。

def policy_threshold(threshold, belief, loc):
  """
  chooses whether to switch side based on whether the belief
      on the current site drops below the threshold

  Args:
    threshold (float): the threshold of belief on the current site,
                        when the belief is lower than the threshold, switch side
    belief (numpy array of float, 2-dimensional): the belief on the
                                                  two sites at a certain time
    loc (int) : the location of the agent at a certain time
                -1 for left side, 1 for right side

  Returns:
    act (string): "stay" or "switch"
  """

  ############################################################################
  ## 1. Modify the code below to generate actions (stay or switch)
  ##    for current belief and location
  ##
  ## Belief is a 2d vector: first element = Prob(fish on Left | measurements)
  ##                       second element = Prob(fish on Right  | measurements)
  ## Returns "switch" if Belief that fish are in your current location < threshold
  ##         "stay" otherwise
  ##
  ## Hint: use loc value to determine which row of belief you need to use
  ##       see the docstring for more information about loc
  ##
  ## 2. After completing the function, comment this line:
  raise NotImplementedError("Student exercise: Please complete the code")
  ############################################################################
  # Write the if statement
  if ...:
    # action below threshold
    act = ...
  else:
    # action above threshold
    act = ...

  return act


# Next line tests your function
test_policy_threshold()

ご覧ください

よくできました!

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Dynamics_threshold_based_policy_Exercise")

インタラクティブデモ 4: 異なる閾値での動的挙動

以下のデモは、あなたが作成したポリシーを使っています!スライダーを操作して、ポリシーで制御される動的挙動を観察してください。

(コードでは stay_prob=0.95high_rew_prob=0.3low_rew_prob=0.1 と指定しています。これらは変更可能ですが、妥当なパラメータです。注意:閾値の変化を段階的に見るには同じ乱数を使い続けてください。異なる例を見たい場合はシードをリフレッシュしてください。)

考える質問:

# @markdown Execute this cell to enable the demo.
display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))

@widgets.interact(threshold=widgets.FloatSlider(.2, description="threshold", min=0., max=1., step=.01),
                  new_seed=widgets.ToggleButtons(options=['Reusing', 'Refreshing'],
                                                  description='Random seed:',
                                                  disabled=False,
                                                  button_style='', # 'success', 'info', 'warning', 'danger' or '',
                                                  icons=['check'] * 2
                                                  ))
def update_ex_4(threshold, new_seed):
  """
    p_stay: probability fish stay
    high_rew_p: p(catch fish) when you're on their side
    low_rew_p : p(catch fish) when you're on other side
    threshold: threshold of belief below which switching is taken

  """
  if new_seed == "Refreshing":
    get_randomness(T=100)

  stay_prob=.95
  high_rew_p=.3
  low_rew_p=.1

  params = [stay_prob, high_rew_p, low_rew_p, threshold]

  # initial condition for fish [fish_initial] and you [loc_initial]
  binaryHMM_test = binaryHMM_belief(params, fish_initial=0, loc_initial=0, choose_policy="threshold", T=100)

  belief, loc, act, measurement, fish_state = binaryHMM_test.generate_process()
  plot_dynamics(belief, loc, act, measurement,
                fish_state, binaryHMM_test.choose_policy)
  plt.show()

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Dynamics_with_different_thresholds_Interactive_Demo_and_Discussion")

セクション 5: 価値関数の実装

# @title Video 5: Evaluate policy
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'aJhffROC74w'), ('Bilibili', 'BV1TD4y1D7K3')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Evaluate_policy_Video")

コーディング演習 5.1: 価値関数の実装

閾値がどれほど良いかを調べましょう。そのために、効用(合計ポイント)を定量化する価値関数を計算します。この価値を使って異なる閾値を比較します。目標は、位置を変える努力を最小限にしつつ、釣れる魚の量を最大化することでした。

価値は単位時間あたりの期待効用の合計です。

V(θ)=1T(tR(st)C(at))V(\theta) = \frac{1}{T}\left( \sum_t R(s_t) - C(a_t) \right)

ここで、R(st)R(s_t) は位置 sts_t で得られる瞬間的報酬、C(at)C(a_t) は選択した行動に対するコストです。魚を釣ると1ポイント得て、位置を変えると cost_sw ポイントのコストを支払います。

報酬と行動の確率に基づいて数学的に平均を取ることもできますが、長時間にわたる_実際の_報酬とコストの平均を取ることで同じ答えが得られます。これを行います。

指示:関数 get_value(rewards, actions, cost_sw) を完成させてください。

def get_value(rewards, actions, cost_sw):
  """
  value function

  Args:
    rewards (numpy array of length T): whether a reward is obtained (1) or not (0) at each time step
    actions (numpy array of length T): action, "stay" or "switch", taken at each time step.
    cost_sw (float): the cost of switching to the other location

  Returns:
    value (float): expected utility per unit time
  """
  actions_int = (actions == "switch").astype(int)

  ############################################################################
  ## 1. Modify the code below to compute the value function (equation V(theta))
  ##
  ## 2. After completing the function, comment this line:
  raise NotImplementedError("Student exercise: Please complete the code")
  ############################################################################
  # Calculate the value function
  value = ...

  return value


# Test your function
test_value_function()

ご覧ください

よくできました!

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Implementing_a_value_function_Exercise")

コーディング演習 5.2: ポリシーを実行する

しきい値の良さを判定する仕組みができたので、ブルートフォース法を使って最適なしきい値を計算しましょう。すべてのしきい値を試し、それぞれの価値をシミュレーションして、最も良いものを選びます。関数 get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw) を完成させてください。関数の出力を可視化するコードは提供しています。このプロットでどのしきい値が最大の効用を持つか観察しましょう。

考える質問:

# Set a large time horizon to calculate meaningful statistics
large_time_horizon = 10000
get_randomness(large_time_horizon)

def run_policy(threshold, p_stay, low_rew_p, high_rew_p):
  """
  This function executes the policy (fully parameterized by the threshold) and
  returns two arrays:
    The sequence of actions taken from time 0 to T
    The sequence of rewards obtained from time 0 to T
  """

  params = [p_stay, low_rew_p, high_rew_p, threshold]
  binaryHMM_test = binaryHMM_belief(params, choose_policy="threshold", T=large_time_horizon)
  _, _, actions, rewards, _ = binaryHMM_test.generate_process()
  return actions, rewards


def get_optimal_threshold(p_stay, low_rew_p, high_rew_p, cost_sw):
  """
  Args:
    p_stay (float): probability of fish staying in their current location
    low_rew_p (float): probability of catching fish when you and the fist are in different locations.
    high_rew_p (float): probability of catching fish when you and the fist are in the same location.
    cost_sw (float): the cost of switching to the other location

  Returns:
    value (float): expected utility per unit time
  """
  ############################################################################
  ## 1. Modify the code below to find the best threshold using brute force
  ##
  ## 2. After completing the function, comment this line:
  raise NotImplementedError("Student exercise: Please complete the code")
  ############################################################################

  # Create an array of 20 equally distanced candidate thresholds (min = 0., max=1.):
  threshold_array = ...

  # Using the function get_value() that you coded before and
  # the function run_policy() that we provide, compute the value of your
  # candidate thresholds:

  # Create an array to store the value of each of your candidates:
  value_array = ...

  for i in ...:
    actions, rewards = ...
    value_array[i] = ...

  # Return the array of candidate thresholds and their respective values

  return threshold_array, value_array


# Feel free to change these parameters
stay_prob = .9         # Fish stay at current location with probability stay_prob
low_rew_prob = 0.1     # Even if fish are somewhere else, you can catch some fish with probability low_rew_prob
high_rew_prob = 0.7    # When you and the fish are in the same place, you can catch fish with probability high_rew_prob
cost_sw = .1           # When you switch locations, you pay this cost: cost_sw


# Visually determine the threshold that obtains the maximum utility.
# Remember, policies are parameterized by a threshold on beliefs:
# when your belief that the fish are on your current side falls below a threshold 𝜃, you switch to the other side.
threshold_array, value_array = get_optimal_threshold(stay_prob, low_rew_prob, high_rew_prob, cost_sw)
plot_value_threshold(threshold_array, value_array)

解答を見る$

出力例:

Solution hint
# @title Submit your feedback
content_review(f"{feedback_prefix}_Run_the_policy_Exercise_and_Discussion")

まとめ

このチュートリアルでは、隠れマルコフモデルと行動を組み合わせて最適制御問題を解きました!これにより、部分観測マルコフ決定過程(POMDP)の基本的な形式が示されました。

観測(釣れた魚)を使って、魚の位置を推定するための信念(事後分布)を構築しました。次に、異なるポリシーの質を評価するための価値関数を計算しました。最後にブルートフォース法を用いて、場所の切り替えの労力を最小限に抑えつつ、できるだけ多くの魚を釣る最適なポリシーを発見しました。

次のチュートリアルでは、ここで使った二値の状態と行動の代わりに連続状態と連続行動を扱います。連続制御でもPOMDPを使うことはできますが、ポリシーが依然として示唆に富むため、完全に観測可能な場合の制御、すなわちマルコフ決定過程(MDP)に焦点を当てます。

# @title Video 6: From discrete to continuous control
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'ndCMgdjv9Gg'), ('Bilibili', 'BV1JA411v7jy')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_From_discrete_to_continuous_control_Video")

ボーナス

ボーナスセクション 1: 最適ポリシーは課題にどう依存するか?

# @title Video 7: Sensitivity of optimal policy
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'wd8IVsKoEfA'), ('Bilibili', 'BV1QK4y1e7N9')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Sensitivity_of_optimal_policy_Bonus_Video")

ボーナスインタラクティブデモ 1: 課題パラメータを探る

このデモでは、さまざまな課題パラメータを操作できます。以下を調整したときに最適なしきい値がどう変わるか観察してください:

これらのパラメータが変わると最適なしきい値がなぜ変わるのか説明できますか?

異なるポリシーの価値の微妙な変化を見るには長いシミュレーションが必要な場合があるので、まずは大まかな傾向を探してください。

# @markdown Make sure you execute this cell to enable the widget!
display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))

@widgets.interact(p_stay=widgets.FloatSlider(.95, description="p(stay)",
                                             min=0., max=1.),
                  p_high_rwd=widgets.FloatSlider(.4, description="p(high_rwd)",
                                                 min=0., max=1.),
                  p_low_rwd=widgets.FloatSlider(.1, description="p(low_rwd)",
                                                min=0., max=1.),
                  cost_sw=widgets.FloatSlider(.2, description="switching cost",
                                              min=0., max=1., step=.01))


def update_ex_bonus(p_stay, p_high_rwd, p_low_rwd, cost_sw):
  """
    p_stay: probability fish stay
    high_rew_p: p(catch fish) when you're on their side
    low_rew_p : p(catch fish) when you're on other side
    cost_sw: switching cost
  """

  # Set a large time horizon to calculate meaningful statistics
  large_time_horizon = 10000
  get_randomness(large_time_horizon)

  threshold_array, value_array = get_optimal_threshold(p_stay,
                                                       p_low_rwd,
                                                       p_high_rwd,
                                                       cost_sw)
  plot_value_threshold(threshold_array, value_array)

解答を見る$

# @title Submit your feedback
content_review(f"{feedback_prefix}_Explore_task_parameters_Bonus_Interactive_Demo_and_Discussion")