Open In Colab   Open in Kaggle

チュートリアル 1: 基本的な強化学習

第3週4日目: 基本的な強化学習

Neuromatch Academyによる

コンテンツ作成者: パブロ・サミュエル・カストロ

コンテンツレビュアー: シャオナン・ワン、シャオメイ・ミ、ジュリア・コスタクルタ、ドーラ・ジーユ・ヤン、アドリタ・ダス

コンテンツ編集者: シャオナン・ワン

制作編集者: スピロス・チャヴリス


チュートリアルの目的

強化学習(RL)は、エージェントが報酬を最大化する行動を学習する問題を定義し解決するための強力なフレームワークです。基本的には、エージェントは現在の世界の状態を観察し、行動を選択し、報酬を受け取り、このフィードバックを使って将来の行動を改善します。RLはこの学習プロセスを形式的かつ最適に記述する方法を提供し、これはもともと動物の行動研究から導かれ、その後人間や動物の脳の観察によって検証されました。

このチュートリアルでは、簡単な例を使ってRLの基本概念を紹介します。最後には、RLがどのように機能し、さまざまな問題の解決にどのように応用できるかをより深く理解できるようになります。

# @title Tutorial slides
from IPython.display import IFrame
# OSF identifier of the slide deck that accompanies this tutorial.
link_id = "ztgws"
print(f"If you want to download the slides: https://osf.io/download/{link_id}/")
# Embed the deck inline through the render URL built from the same identifier.
IFrame(
    src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render",
    width=854,
    height=480,
)

セットアップ

このノートブックはGPU不要です!

# @title Install and import feedback gadget


from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    """Renders the feedback widget for one section of this notebook.

    Args:
      notebook_section: Identifier of the section the feedback refers to.

    Returns:
      Whatever `DatatopsContentReviewContainer(...).render()` returns.
    """
    datatops_config = {
        "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
        "name": "neuromatch_dl",
        "user_key": "f379rz8y",
    }
    container = DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        datatops_config,
    )
    return container.render()


# Prefix shared by every feedback identifier passed to content_review below.
feedback_prefix = "W3D4_T1"
# Imports
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple

セクション1: 強化学習の歴史

このセクションでは、強化学習(RL)の歴史を逆時系列で簡単に概観します。これにより、なぜRLが興味深いテーマなのかを理解する助けになります!

# @title Video 1: Intro to RL
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """Inline frame that embeds a tutorial video hosted on Bilibili or OSF."""

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for `source` and initialises the IFrame.

    Args:
      id: Video identifier on the hosting platform; also kept on `self.id`.
      source: Hosting platform, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL; unused for 'Osf'.
      width: Frame width forwarded to IFrame.
      height: Frame height forwarded to IFrame.
      **kwargs: Extra keyword arguments forwarded to IFrame.

    Raises:
      ValueError: If `source` is not one of the supported platforms.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with a confusing UnboundLocalError.
      raise ValueError(f'Unsupported video source: {source!r}')
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Renders each video into its own output widget.

  Args:
    video_ids: Sequence of (source, id) pairs, where source is 'Youtube',
      'Bilibili' or 'Osf'.
    W: Player width.
    H: Player height.
    fs: Forwarded as the `fs` option of each player.

  Returns:
    List with one `widgets.Output` per entry of `video_ids`, in order.
  """
  outputs = []
  for entry in video_ids:
    source, vid = entry[0], entry[1]
    panel = widgets.Output()
    with panel:
      if source == 'Youtube':
        player = YouTubeVideo(id=vid, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={player.id}')
      else:
        player = PlayVideo(id=vid, source=source, width=W, height=H, fs=fs,
                           autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{player.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{player.id}')
      display(player)
    outputs.append(panel)
  return outputs


# One (platform, video id) pair per hosting site; each pair gets its own tab.
video_ids = [('Youtube', 'x-NyX8bRsTQ'), ('Bilibili', 'BV1wk4y1M7ae')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
# Title every tab with the name of the platform hosting its video.
for i, entry in enumerate(video_ids):
  tabs.set_title(i, entry[0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Intro_to_RL_Video")

セクション2: 強化学習とは何か

非常にシンプルな問題を使って、RLとは何か、問題定式化を定義する主な構成要素について高レベルで説明します。

補足: もっと読みたい方は、RLの定番参考書であるサットン&バルトの強化学習の本を参照してください。

セクション2.1: グリッドワールド

グリッドワールドは非常にシンプルな「ナビゲーション」問題で、RLの問題や解決策を動機付けるのに非常に役立ちます。RL研究でよく使われるので、慣れておくと良いでしょう!

このチュートリアルでは、報酬が一角にある空の部屋というシンプルなグリッドワールド問題を使います。

以下の例は、少し難しい2つ目のグリッドワールドを定義しています。自由に自分のグリッドワールドを作ってみてください!

補足: ウェブ上でグリッドワールドのRLを試したい場合は、このGridWorldプレイグラウンドウェブアプリをチェックしてみてください。

# @title Video 2: Grid World
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """Inline frame that embeds a tutorial video hosted on Bilibili or OSF."""

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for `source` and initialises the IFrame.

    Args:
      id: Video identifier on the hosting platform; also kept on `self.id`.
      source: Hosting platform, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL; unused for 'Osf'.
      width: Frame width forwarded to IFrame.
      height: Frame height forwarded to IFrame.
      **kwargs: Extra keyword arguments forwarded to IFrame.

    Raises:
      ValueError: If `source` is not one of the supported platforms.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with a confusing UnboundLocalError.
      raise ValueError(f'Unsupported video source: {source!r}')
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Renders each video into its own output widget.

  Args:
    video_ids: Sequence of (source, id) pairs, where source is 'Youtube',
      'Bilibili' or 'Osf'.
    W: Player width.
    H: Player height.
    fs: Forwarded as the `fs` option of each player.

  Returns:
    List with one `widgets.Output` per entry of `video_ids`, in order.
  """
  outputs = []
  for entry in video_ids:
    source, vid = entry[0], entry[1]
    panel = widgets.Output()
    with panel:
      if source == 'Youtube':
        player = YouTubeVideo(id=vid, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={player.id}')
      else:
        player = PlayVideo(id=vid, source=source, width=W, height=H, fs=fs,
                           autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{player.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{player.id}')
      display(player)
    outputs.append(panel)
  return outputs


# One (platform, video id) pair per hosting site; each pair gets its own tab.
video_ids = [('Youtube', '4r400A5GNfE'), ('Bilibili', 'BV1GV411M7hk')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
# Title every tab with the name of the platform hosting its video.
for i, entry in enumerate(video_ids):
  tabs.set_title(i, entry[0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Grid_world_Video")

コーディング演習1: グリッドワールドの最短経路プランナーをコードしよう

# @title Create the GridWorldPlanner object (defaults to simple example)

# Emoji used to print each ASCII symbol of a grid or of a policy.
ASCII_TO_EMOJI = {
    ' ': '⬜',
    '*': '⬛',
    'g': '⭐',
    '<': '◀️',
    '>': '▶️',
    'v': '🔽',
    '^': '🔼',
}

# (row, column) offset produced by each action.
ACTION_EFFECTS = {
    '<': (0, -1),
    '>': (0, 1),
    'v': (1, 0),
    '^': (-1, 0),
}
# Action symbols, in the same order as the keys of ACTION_EFFECTS.
ACTIONS = list(ACTION_EFFECTS)


def get_emoji(c, policy=None):
  """Returns the emoji for grid symbol `c`, optionally replaced by a policy.

  Args:
    c: ASCII symbol of the cell; must be a key of ASCII_TO_EMOJI.
    policy: Optional ASCII symbol of the action chosen in the cell. It is
      ignored for the goal cell ('g') and when it is blank (' ').

  Returns:
    The emoji string to print for the cell.
  """
  assert c in ASCII_TO_EMOJI
  show_policy = policy is not None and c != 'g'
  if show_policy:
    assert policy in ASCII_TO_EMOJI
  # A non-blank policy symbol takes precedence over the cell's own symbol.
  symbol = policy if show_policy and policy != ' ' else c
  return ASCII_TO_EMOJI[symbol]


class GridWorldBase(object):
  """A 2D grid world with walls, a single goal cell and a per-cell policy."""

  def __init__(self, world_spec: Optional[np.ndarray] = None):
    """Creates a GridWorld object with an empty policy.

    Args:
      world_spec: Optional 2D array specification of GridWorld, where '*'
                  marks a wall and 'g' the goal. If None, will use default
                  square room.
    """
    if world_spec is None:
      self.world_spec = np.array(
          [['*', '*', '*', '*', '*', '*'],
           ['*', ' ', ' ', ' ', ' ', '*'],
           ['*', ' ', ' ', ' ', ' ', '*'],
           ['*', ' ', ' ', ' ', ' ', '*'],
           ['*', ' ', ' ', ' ', 'g', '*'],
           ['*', '*', '*', '*', '*', '*']]
      )
    else:
      assert len(world_spec.shape) == 2
      self.world_spec = world_spec

    goal_positions = np.where(self.world_spec == 'g')
    assert len(goal_positions[0]) == 1  # Only one goal.
    # The policy starts blank (' ') in every cell.
    self.policy = np.full_like(self.world_spec, ' ')
    # **Note**: These may be useful for your planner!
    # [row, column] of the goal cell.
    self.goal_cell = [axis[0] for axis in goal_positions]

  def get_neighbours(self, cell: Tuple[int, int]):
    """Get the neighbours of a cell.

    **Note**: You should use this when writing your planner!

    Args:
      cell: cell position as (row, column).

    Returns:
      Dict containing neighbouring cells for each of the 4 possible directions.
      A move that would leave the grid or enter a wall maps to `cell` itself.

    Raises:
      ValueError: If `cell` lies outside the grid.
    """
    height, width = self.world_spec.shape
    i, j = cell
    # Row 0 and column 0 are valid positions; only reject cells that are
    # really outside the array.
    if i < 0 or i >= height or j < 0 or j >= width:
      raise ValueError(f'Invalid cell position: {cell}')
    neighbours = {}
    for a in ACTIONS:
      di, dj = ACTION_EFFECTS[a]
      ni, nj = i + di, j + dj
      blocked = (ni < 0 or nj < 0 or ni >= height or nj >= width or
                 self.world_spec[ni, nj] == '*')
      # Remain in same cell when the move is blocked.
      neighbours[a] = cell if blocked else [ni, nj]
    return neighbours

  def plan(self):
    """Constructs a random policy.

    Every non-wall cell gets an action drawn uniformly from ACTIONS.

    **Note**: you will make something better further down!
    """
    rows, cols = self.policy.shape[0], self.policy.shape[1]
    for i in range(rows):
      for j in range(cols):
        if self.world_spec[i, j] == '*':  # Nothing to do for walls.
          continue
        self.policy[i, j] = ACTIONS[np.random.choice(len(ACTIONS))]

  def draw(self, include_policy: bool = False):
    """Draw the grid, and (optionally) include the policy.

    Args:
      include_policy: If True, cells show their policy action instead of
                      their own symbol (see `get_emoji` for the exceptions).
    """
    for i, row in enumerate(self.world_spec):
      # Without a policy every cell is drawn with `policy=None`.
      row_policy = self.policy[i] if include_policy else [None] * len(row)
      print(''.join(get_emoji(c, p) for c, p in zip(row, row_policy)))


# Demo: print the default room, then print it again with a random policy.
gwb = GridWorldBase()
print('Simple GridWorld:')
gwb.draw()
gwb.plan()
print('Random policy:')
gwb.draw(True)
class GridWorldPlanner(GridWorldBase):
  """A GridWorld that finds a better policy."""

  def plan(self):
    """Define a better planner!

    This gives you a starting point by setting the proper actions in the cells
    surrounding the goal cell. All other cells keep the random action assigned
    by the base class.

    **Assignment:** Do the rest!
    """
    super().plan()
    # If the goal reaches a neighbour with action `a`, the neighbour reaches
    # the goal with the opposite action.
    opposite = {'<': '>', '>': '<', '^': 'v', 'v': '^'}
    goal_queue = [self.goal_cell]
    goals_done = set()
    # `goal_cell` is a list; use a tuple so that the comparison with the
    # (tuple) neighbours below can actually succeed.
    goal = tuple(goal_queue.pop(0))  # pop from front of list
    goal_neighbours = self.get_neighbours(goal)
    goals_done.add(goal)

    for a, nbr in goal_neighbours.items():
      nbr = tuple(nbr)
      # Skip blocked moves (the neighbour is the goal itself) and cells that
      # were already handled.
      if nbr == goal or nbr in goals_done:
        continue
      self.policy[nbr[0], nbr[1]] = opposite[a]
      goal_queue.append(nbr)


# Demo: print the default room, then the policy produced by GridWorldPlanner.
gwp = GridWorldPlanner()
print('Simple GridWorld:')
gwp.draw()
gwp.plan()
print('Better policy:')
gwp.draw(True)

より良いプランナーを作成しましょう!

class GridWorldPlanner(GridWorldBase):
  """A GridWorld that finds a better policy."""

  def plan(self):
    """Define a better planner!

    This gives you a starting point by setting the proper actions in the cells
    surrounding the goal cell.

    Exercise skeleton: replace the `...` loop condition and remove the `raise`
    so that the step below is repeated for every cell taken from `goal_queue`.

    **Assignment:** Do the rest!
    """
    super().plan()
    goal_queue = [self.goal_cell]
    goals_done = set()
    #################################################
    # Implement a better planner
    raise NotImplementedError("Define a better planner!")
    #################################################
    while ...:
      goal = goal_queue.pop(0)  # pop from front of list
      goal_neighbours = self.get_neighbours(goal)
      goals_done.add(tuple(goal))
      for a in goal_neighbours:
        nbr = tuple(goal_neighbours[a])
        if nbr == goal:
          continue
        if nbr not in goals_done:
          # Point the neighbour back towards the cell it was reached from.
          if a == '<':
            self.policy[nbr[0], nbr[1]] = '>'
          elif a == '>':
            self.policy[nbr[0], nbr[1]] = '<'
          elif a == '^':
            self.policy[nbr[0], nbr[1]] = 'v'
          else:
            self.policy[nbr[0], nbr[1]] = '^'
          goal_queue.append(nbr)

解答を見る

# Run the planner defined above on the default room and print its policy.
gwp = GridWorldPlanner()
print('Simple GridWorld:')
gwp.draw()
gwp.plan()
print('Better policy:')
gwp.draw(True)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Make_a_better_planner_Exercise")
# @title Try it out in a harder problem.
# Two rooms separated by a wall column that is open only in the bottom row;
# the goal is in the top-right corner.
harder_grid = np.array(
    [['*', '*', '*', '*', '*', '*', '*', '*', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', 'g', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],
     ['*', ' ', ' ', ' ', '*', ' ', ' ', ' ', '*'],
     ['*', ' ', ' ', ' ', ' ', ' ', ' ', ' ', '*'],
     ['*', '*', '*', '*', '*', '*', '*', '*', '*'],
    ]
)
# Compare the random base policy with the planner's policy on the same grid.
gwb_2 = GridWorldBase(harder_grid)
gwp_2 = GridWorldPlanner(harder_grid)
print('Harder GridWorld:')
gwb_2.draw()
gwb_2.plan()
print('Random policy:')
gwb_2.draw(True)
print('Better policy:')
gwp_2.plan()
gwp_2.draw(True)

セクション2.2: マルコフ決定過程(MDP)

RL問題の定式化は伝統的にマルコフ決定過程(MDP)を通じて行われます。このセクションでは、必要な記法を紹介し、シンプルなグリッドワールドに対応するMDPを定義するコードを書きます。

補足: マーティン・パターマン(Martin Puterman)のマルコフ決定過程に関する本は、さらに詳しく学びたい方に最適な参考書です。

# @title Video 3: Markov Decision Process
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """Inline frame that embeds a tutorial video hosted on Bilibili or OSF."""

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for `source` and initialises the IFrame.

    Args:
      id: Video identifier on the hosting platform; also kept on `self.id`.
      source: Hosting platform, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL; unused for 'Osf'.
      width: Frame width forwarded to IFrame.
      height: Frame height forwarded to IFrame.
      **kwargs: Extra keyword arguments forwarded to IFrame.

    Raises:
      ValueError: If `source` is not one of the supported platforms.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with a confusing UnboundLocalError.
      raise ValueError(f'Unsupported video source: {source!r}')
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Renders each video into its own output widget.

  Args:
    video_ids: Sequence of (source, id) pairs, where source is 'Youtube',
      'Bilibili' or 'Osf'.
    W: Player width.
    H: Player height.
    fs: Forwarded as the `fs` option of each player.

  Returns:
    List with one `widgets.Output` per entry of `video_ids`, in order.
  """
  outputs = []
  for entry in video_ids:
    source, vid = entry[0], entry[1]
    panel = widgets.Output()
    with panel:
      if source == 'Youtube':
        player = YouTubeVideo(id=vid, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={player.id}')
      else:
        player = PlayVideo(id=vid, source=source, width=W, height=H, fs=fs,
                           autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{player.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{player.id}')
      display(player)
    outputs.append(panel)
  return outputs


# One (platform, video id) pair per hosting site; each pair gets its own tab.
video_ids = [('Youtube', 'AfIsG1I4MRE'), ('Bilibili', 'BV1VV411M7pX')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
# Title every tab with the name of the platform hosting its video.
for i, entry in enumerate(video_ids):
  tabs.set_title(i, entry[0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Markov_Decision_Process_Video")

コーディング演習2: グリッドワールド仕様からMDPを作成しよう

MDPの $P$ 行列と $R$ 行列を作成してください。

class MDPBase(object):
  """This object creates a proper MDP from a GridWorld object.

  Every non-wall cell of the grid becomes one state. `P` has shape
  (num_states, num_actions, num_states), `R` has shape
  (num_states, num_actions) and `pi` stores one action index per state.
  """

  def __init__(self, grid_world: GridWorldBase):
    """Constructs an MDP from a GridWorldBase object.

    Exercise skeleton: the `...` placeholders in the final loop must be filled
    in and the `raise` removed before this constructor can complete.

    Args:
      grid_world: GridWorld specification.
    """
    # Determine how many valid states there are and create empty matrices.
    self.grid_world = grid_world
    self.num_states = np.sum(grid_world.world_spec != '*')
    self.num_actions = len(ACTIONS)
    self.P = np.zeros((self.num_states, self.num_actions, self.num_states))
    self.R = np.zeros((self.num_states, self.num_actions))
    self.pi = np.zeros(self.num_states, dtype=np.int32)
    # Create mapping from cell positions to state ID.
    state_idx = 0
    self.cell_to_state = np.ones(grid_world.world_spec.shape, dtype=np.int32) * -1  # Defaults to -1.
    self.state_to_cell = {}
    # Rows are scanned top-to-bottom and cells left-to-right, so state IDs
    # follow that order; walls keep the -1 default.
    for i, row in enumerate(grid_world.world_spec):
      for j, cell in enumerate(row):
        if cell == '*':
          continue
        if cell == 'g':
          self.goal_state = state_idx
        self.cell_to_state[i, j] = state_idx
        self.state_to_cell[state_idx] = (i, j)
        state_idx += 1
    #################################################
    # States should be numbered from left-to-right and from top-to-bottom.
    raise NotImplementedError("Calculate P and R")
    #################################################
    # Assign transition probabilities and rewards accordingly.
    for s in range(...):
      neighbours = ...
      # `a` is the position of `action` in the neighbours dict.
      for a, action in enumerate(neighbours):
        nbr = ...
        s2 = self.cell_to_state[..., ...]
        self.P[s, a, s2] = 1.0  # Deterministic transitions
        # Reward 1 for any (state, action) whose successor is the goal state.
        if s2 == self.goal_state:
          self.R[s, a] = 1.0

  def draw(self, include_policy: bool = False):
    """Draws the underlying GridWorld, optionally with the MDP policy.

    Args:
      include_policy: Forwarded to `GridWorldBase.draw`.
    """
    # First make sure we convert our MDP policy into the GridWorld policy.
    for s in range(self.num_states):
      r, c = self.state_to_cell[s]
      self.grid_world.policy[r, c] = ACTIONS[self.pi[s]]
    self.grid_world.draw(include_policy)

  def plan(self):
    """Define a planner

    Exercise skeleton: states are taken from the front of `goal_queue` and
    newly reached states are appended to its back, starting from the goal
    state. The `...` placeholders must be filled in and the `raise` removed.
    """
    goal_queue = [self.goal_state]
    goals_done = set()
    #################################################
    # Set the proper actions
    raise NotImplementedError("Implement `plan` function!")
    #################################################
    while goal_queue:
      goal = goal_queue.pop(0)  # pop from front of list
      nbr_states, nbr_actions = ...
      goals_done.add(goal)
      for s, a in zip(..., ...):
        if s == goal:
          continue
        if s not in goals_done:
          self.pi[s] = ...
          goal_queue.append(s)

解答を見る

mdpb = MDPBase(gwb)

# Verify the transitions were properly created.
for i, row in enumerate(mdpb.grid_world.world_spec):
  for j, cell in enumerate(row):
    if cell == '*':
      continue
    neighbours = mdpb.grid_world.get_neighbours((i, j))
    s = mdpb.cell_to_state[i, j]
    for a, (action, nbr) in enumerate(neighbours.items()):
      s2 = mdpb.cell_to_state[nbr[0], nbr[1]]
      # Each (state, action) pair must move to exactly one successor state.
      assert np.sum(mdpb.P[s, a, :]) == 1.0
      assert mdpb.P[s, a, s2] == 1.0
      # Only transitions that end in the goal state are rewarded.
      expected_reward = 1.0 if s2 == mdpb.goal_state else 0.0
      assert mdpb.R[s, a] == expected_reward

print('P and R matrices successfully created!')
print('GridWorld:')
mdpb.draw()
print('Shortest path policy:')
mdpb.plan()
mdpb.draw(True)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Create_an_MDP_Exercise")

セクション2.3: $Q$ 値

$Q$ 値はRLアルゴリズムの中心であり、特定の状態で特定の行動を取ることの「望ましさ」を定量化します。エージェントは訓練を通じてこれらの値を更新し、その推定値を使って行動を決定します。

# @title Video 4: Q-values
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """Inline frame that embeds a tutorial video hosted on Bilibili or OSF."""

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for `source` and initialises the IFrame.

    Args:
      id: Video identifier on the hosting platform; also kept on `self.id`.
      source: Hosting platform, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL; unused for 'Osf'.
      width: Frame width forwarded to IFrame.
      height: Frame height forwarded to IFrame.
      **kwargs: Extra keyword arguments forwarded to IFrame.

    Raises:
      ValueError: If `source` is not one of the supported platforms.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with a confusing UnboundLocalError.
      raise ValueError(f'Unsupported video source: {source!r}')
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Renders each video into its own output widget.

  Args:
    video_ids: Sequence of (source, id) pairs, where source is 'Youtube',
      'Bilibili' or 'Osf'.
    W: Player width.
    H: Player height.
    fs: Forwarded as the `fs` option of each player.

  Returns:
    List with one `widgets.Output` per entry of `video_ids`, in order.
  """
  outputs = []
  for entry in video_ids:
    source, vid = entry[0], entry[1]
    panel = widgets.Output()
    with panel:
      if source == 'Youtube':
        player = YouTubeVideo(id=vid, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={player.id}')
      else:
        player = PlayVideo(id=vid, source=source, width=W, height=H, fs=fs,
                           autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{player.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{player.id}')
      display(player)
    outputs.append(panel)
  return outputs


# One (platform, video id) pair per hosting site; each pair gets its own tab.
video_ids = [('Youtube', 'u2-uqRiJHuM'), ('Bilibili', 'BV1gk4y1M7iK')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
# Title every tab with the name of the platform hosting its video.
for i, entry in enumerate(video_ids):
  tabs.set_title(i, entry[0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Q_values_Video")

コーディング演習3: 残りステップ数(steps-to-go)ソルバーを作成しよう

ゴールまでの残りステップ数を $Q$ 値として保持する新しいMDPクラスを作成してください。

class MDPToGo(MDPBase):
  """MDP whose Q matrix stores steps-to-go rather than returns."""

  def __init__(self, grid_world: GridWorldBase):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
    """
    super().__init__(grid_world)
    # One entry per (state, action) pair, initialised to zero.
    self.Q = np.zeros((self.num_states, self.num_actions))

  def computeQ(self):
    """Store discounted steps-to-go in an SxA matrix called Q.

    This matrix will then be used to extract the optimal policy.

    Exercise skeleton: the `...` placeholders must be filled in and the
    `raise` removed.
    """
    #################################################
    # Implement a function to compute Q
    raise NotImplementedError("Implement `ComputeQ` function!")
    #################################################
    # Queue of (state, steps-to-go) pairs, seeded with the goal state.
    goal_queue = [(self.goal_state, 0)]
    goals_done = set()
    while goal_queue:
      goal, steps_to_go = ...  # pop from front of list
      steps_to_go += ...  # Increase the number of steps to goal.
      nbr_states, nbr_actions = ...
      goals_done.add(...)
      for s, a in zip(..., ...):
        if goal == self.goal_state and s == self.goal_state:
          self.Q[s, a] = ...
        elif s == goal:
          # If (s, a) leads to itself then we have an infinite loop (since
          # we're assuming deterministic transitions).
          self.Q[s, a] = ...
        else:
          self.Q[s, a] = ...
        if s not in goals_done:
          goal_queue.append((..., ...))

  def plan(self):
    """Now planning is just doing an argmin over the Q-values!

    Note that this is a little different than standard Q-learning (where we do
    an argmax), since our Q-values currently store steps-to-go.
    """
    # For every state, pick the action index with the smallest Q entry.
    self.pi = np.argmin(self.Q, axis=-1)

解答を見る

# Demo: solve the default room with the steps-to-go solver.
mdpTg = MDPToGo(gwb)
print('GridWorld:')
mdpTg.draw()
# Compute Q, then extract policy from it.
mdpTg.computeQ()
mdpTg.plan()
print('Optimal policy:')
mdpTg.draw(True)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Create_a_step_to_go_solver_Exercise")

セクション3: 価値反復

価値反復は、ベルマンバックアップを繰り返し実行して $Q$ 値と $V$ 値の推定を継続的に改善する反復アルゴリズムです。これは $P$ と $R$ へのアクセスを前提としています(通常RLではアクセスできません)が、後で説明するQ学習の基礎となっています。


知っていましたか? リチャード・ベルマンは価値反復のために動的計画法(コンピュータサイエンスの基礎の一部)を開発しました。

# @title Video 5: Value iteration
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """Inline frame that embeds a tutorial video hosted on Bilibili or OSF."""

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for `source` and initialises the IFrame.

    Args:
      id: Video identifier on the hosting platform; also kept on `self.id`.
      source: Hosting platform, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL; unused for 'Osf'.
      width: Frame width forwarded to IFrame.
      height: Frame height forwarded to IFrame.
      **kwargs: Extra keyword arguments forwarded to IFrame.

    Raises:
      ValueError: If `source` is not one of the supported platforms.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with a confusing UnboundLocalError.
      raise ValueError(f'Unsupported video source: {source!r}')
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Renders each video into its own output widget.

  Args:
    video_ids: Sequence of (source, id) pairs, where source is 'Youtube',
      'Bilibili' or 'Osf'.
    W: Player width.
    H: Player height.
    fs: Forwarded as the `fs` option of each player.

  Returns:
    List with one `widgets.Output` per entry of `video_ids`, in order.
  """
  outputs = []
  for entry in video_ids:
    source, vid = entry[0], entry[1]
    panel = widgets.Output()
    with panel:
      if source == 'Youtube':
        player = YouTubeVideo(id=vid, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={player.id}')
      else:
        player = PlayVideo(id=vid, source=source, width=W, height=H, fs=fs,
                           autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{player.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{player.id}')
      display(player)
    outputs.append(panel)
  return outputs


# One (platform, video id) pair per hosting site; each pair gets its own tab.
video_ids = [('Youtube', 'XIcX37uaRF0'), ('Bilibili', 'BV1SF41197kC')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
# Title every tab with the name of the platform hosting its video.
for i, entry in enumerate(video_ids):
  tabs.set_title(i, entry[0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Value_Iteration_Video")

コーディング演習4: 価値反復を実装しよう

価値反復を行う新しいMDPクラスを作成してください。

class MDPValueIteration(MDPToGo):
  """MDP solved with value iteration, using discount factor `gamma`."""

  def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
      gamma: Discount factor.
    """
    super().__init__(grid_world)
    self.gamma = gamma

  def computeQ(self, error_tolerance : float = 1e-5):
    """Compute Q and V vectors via value iteration.

    Exercise skeleton: the `...` placeholders must be filled in and the
    `raise` removed. The loop runs until `error` is no larger than
    `error_tolerance`; afterwards V is taken as the per-state maximum of Q.

    Args:
      error_tolerance: How much error we tolerate between successive Q updates.
    """
    self.Q = np.zeros((self.num_states, self.num_actions))
    num_iterations = 0
    error = np.inf
    #################################################
    # Write this method!
    # First find Q, and then extract V from Q.
    # Hint: Use matrix multiplication instead of for loops!
    raise NotImplementedError("Implement `computeQ` function!")
    #################################################
    while error > error_tolerance:
      new_Q = ...
      max_next_Q = ...
      for a in range(self.num_actions):
        new_Q[:, a] = ...
      error = ...
      self.Q = np.copy(new_Q)
      num_iterations += 1
    self.V = np.max(self.Q, axis=-1)
    print(f'Q and V found in {num_iterations} iterations with an error tolerance of {error_tolerance}.')

  def plan(self):
    """Now planning is just doing an argmax over the Q-values!

    Exercise skeleton: replace the `...` placeholder and remove the `raise`.
    """
    #################################################
    # Note that we're going back to argmax, since the Q-values now represent proper
    # "returns-to-go", so we want to maximize that.
    # Write this method! It should be a one-liner, and very similar to what you
    # used for extracting V from Q.
    raise NotImplementedError("Implement `plan` function!")
    #################################################
    self.pi = ...

  def _draw_v(self):
    """Draw the V values.

    Requires `self.V`, which is set by `computeQ`.
    """
    min_v = np.min(self.V)
    max_v = np.max(self.V)
    wall_v = 2 * min_v - max_v  # Creating a smaller value for walls.
    # Every cell starts at wall_v; non-wall cells are overwritten below.
    grid_values = np.ones_like(self.grid_world.world_spec, dtype=np.int32) * wall_v
    # Fill in the V values in grid cells.
    for s in range(self.num_states):
      cell = self.state_to_cell[s]
      grid_values[cell[0], cell[1]] = self.V[s]

    fig, ax = plt.subplots()
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    grid = ax.matshow(grid_values)
    # Colour limits span from the wall value up to the largest state value.
    grid.set_clim(wall_v, max_v)
    fig.colorbar(grid)

  def draw(self, draw_mode: str = 'grid'):
    """Draw the GridWorld according to specified mode.

    Args:
      draw_mode: Specification of what mode to draw. Supported options:
                 'grid': Draw the base GridWorld.
                 'policy': Display the policy.
                 'values': Display the values for each state.
    """
    # 'values' is plotted here; every other mode is delegated to the inherited
    # draw, which copies the MDP policy into the GridWorld before printing.
    if draw_mode == 'values':
      self._draw_v()
    else:
      super().draw(draw_mode == 'policy')

解答を見る

# Demo: solve the default room with value iteration.
mdpVi = MDPValueIteration(gwb)
print('GridWorld:')
mdpVi.draw()
# Compute Q, then extract policy from it.
mdpVi.computeQ()
mdpVi.plan()
print('Optimal policy:')
mdpVi.draw('policy')
# Plot the state values found by computeQ.
mdpVi.draw('values')
# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_Value_Iteration_Exercise")

セクション4: 方策反復

$Q$ 値や $V$ 値の推定を収束まで反復する代わりに、方策 $\pi$ 自体を直接反復してみてはどうでしょう?

方策反復はまさにそれを行い、価値反復よりも少ないステップで解に到達することもあります。

# @title Video 6: Policy iteration
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """Inline frame that embeds a tutorial video hosted on Bilibili or OSF."""

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for `source` and initialises the IFrame.

    Args:
      id: Video identifier on the hosting platform; also kept on `self.id`.
      source: Hosting platform, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL; unused for 'Osf'.
      width: Frame width forwarded to IFrame.
      height: Frame height forwarded to IFrame.
      **kwargs: Extra keyword arguments forwarded to IFrame.

    Raises:
      ValueError: If `source` is not one of the supported platforms.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with a confusing UnboundLocalError.
      raise ValueError(f'Unsupported video source: {source!r}')
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Renders each video into its own output widget.

  Args:
    video_ids: Sequence of (source, id) pairs, where source is 'Youtube',
      'Bilibili' or 'Osf'.
    W: Player width.
    H: Player height.
    fs: Forwarded as the `fs` option of each player.

  Returns:
    List with one `widgets.Output` per entry of `video_ids`, in order.
  """
  outputs = []
  for entry in video_ids:
    source, vid = entry[0], entry[1]
    panel = widgets.Output()
    with panel:
      if source == 'Youtube':
        player = YouTubeVideo(id=vid, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={player.id}')
      else:
        player = PlayVideo(id=vid, source=source, width=W, height=H, fs=fs,
                           autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{player.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{player.id}')
      display(player)
    outputs.append(panel)
  return outputs


# One (platform, video id) pair per hosting site; each pair gets its own tab.
video_ids = [('Youtube', 'Z77kVYiRm_4'), ('Bilibili', 'BV17u411b7R1')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
# Title every tab with the name of the platform hosting its video.
for i, entry in enumerate(video_ids):
  tabs.set_title(i, entry[0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Policy_Iteration_Video")

コーディング演習5: 方策反復を実装しよう

方策反復を行う新しいMDPクラスを作成してください。

class MDPPolicyIteration(MDPToGo):
  """MDP solved with policy iteration, using discount factor `gamma`."""

  def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
      gamma: Discount factor.
    """
    super().__init__(grid_world)
    self.gamma = gamma

  def findPi(self):
    """Find π policy.

    Exercise skeleton: the `...` placeholders must be filled in and the
    `raise` removed. The loop repeats until `new_pi` equals `self.pi` in
    every state.
    """
    self.Q = np.zeros((self.num_states, self.num_actions))
    self.pi = np.zeros(self.num_states, dtype=np.int32)
    num_iterations = 0
    #################################################
    # Compute π, which involves computing Q.
    # Once you have π and Q, find V.
    # Hint: Your value iteration solution will be useful here.
    raise NotImplementedError("Implement `findPi` function!")
    #################################################
    new_pi = ...  # initialize to ones
    while np.any(new_pi != self.pi):
      new_pi = ...
      new_Q = ...  # initialize to zeros
      next_V = ...
      for a in range(self.num_actions):
        new_Q[:, a] = ...
      self.Q = np.copy(new_Q)
      self.pi = ...
      num_iterations += 1
    self.V = ...
    print(f'Q and V found in {num_iterations} iterations.')

  def _draw_v(self):
    """Draw the V values.

    Requires `self.V`, which is set by `findPi`.
    """
    min_v = np.min(self.V)
    max_v = np.max(self.V)
    wall_v = 2 * min_v - max_v  # Creating a smaller value for walls.
    # Every cell starts at wall_v; non-wall cells are overwritten below.
    grid_values = np.ones_like(self.grid_world.world_spec, dtype=np.int32) * wall_v
    # Fill in the V values in grid cells.
    for s in range(self.num_states):
      cell = self.state_to_cell[s]
      grid_values[cell[0], cell[1]] = self.V[s]

    fig, ax = plt.subplots()
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    grid = ax.matshow(grid_values)
    # Colour limits span from the wall value up to the largest state value.
    grid.set_clim(wall_v, max_v)
    fig.colorbar(grid)

  def draw(self, draw_mode: str = 'grid'):
    """Draw the GridWorld according to specified mode.

    Args:
      draw_mode: Specification of what mode to draw. Supported options:
                 'grid': Draw the base GridWorld.
                 'policy': Display the policy.
                 'values': Display the values for each state.
    """
    # 'values' is plotted here; every other mode is delegated to the inherited
    # draw, which copies the MDP policy into the GridWorld before printing.
    if draw_mode == 'values':
      self._draw_v()
    else:
      super().draw(draw_mode == 'policy')

解答を見る

# Demo: solve the default room with policy iteration.
mdpPi = MDPPolicyIteration(gwb)
print('GridWorld:')
mdpPi.draw()
# findPi sets the policy, Q and V in a single call.
mdpPi.findPi()
print('Optimal policy:')
mdpPi.draw('policy')
# Plot the state values found by findPi.
mdpPi.draw('values')
# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_Policy_Iteration_Exercise")

セクション5: Q学習アルゴリズム

RLでは $P$ も $R$ も利用できないため、価値反復や方策反復を使って最適行動を見つけることはできません。

しかしQ学習はベルマンバックアップをオンライン学習アルゴリズムに組み込み、緩やかな条件下で真の $Q$ 値に収束することが示されています!

# @title Video 7: Q-learning
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """IFrame that embeds a video hosted on Bilibili or OSF.

  The video id is kept on the instance as `self.id`.
  """

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for the given host and initialises the IFrame.

    Args:
      id: Video identifier on the hosting site.
      source: Hosting site, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL (unused for 'Osf').
      width: Frame width.
      height: Frame height.
      **kwargs: Forwarded unchanged to the IFrame constructor.

    Raises:
      ValueError: If `source` is not one of the supported hosts.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with an UnboundLocalError that hides the real cause.
      raise ValueError(
          f"Unsupported video source {source!r}; expected 'Bilibili' or 'Osf'.")
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Creates one output widget per video and displays its player inside it.

  Args:
    video_ids: Sequence of (source, id) entries. 'Youtube' entries use
      YouTubeVideo; every other source goes through PlayVideo.
    W: Player width.
    H: Player height.
    fs: Forwarded to the player as its `fs` argument.

  Returns:
    List of `widgets.Output` objects, one per entry of `video_ids`.
  """
  outputs = []
  for entry in video_ids:
    panel = widgets.Output()
    # Everything printed or displayed inside the context lands in `panel`.
    with panel:
      source, identifier = entry[0], entry[1]
      if source == 'Youtube':
        video = YouTubeVideo(id=identifier, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=identifier, source=source, width=W, height=H,
                          fs=fs, autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    outputs.append(panel)
  return outputs


# Each entry is a (source, id) pair; the source name doubles as the tab title.
video_ids = [('Youtube', 'HqOdNTdpppE'), ('Bilibili', 'BV1az4y1n7pb')]
tab_contents = display_videos(video_ids, W=854, H=480)
# Put every player in its own tab of a single Tab widget.
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Q_learning_Video")

コーディング演習6: $Q$学習を実装しよう

$Q$学習クラスを作成してください。

class QLearner(MDPValueIteration):
  """Q-learning agent for the GridWorld MDP (exercise scaffold).

  Keeps a Q-table `self.Q` with one row per state and one column per action,
  plus the state the agent currently occupies in `self.current_state`.
  `learnQ` is intentionally left incomplete for the exercise.
  """

  def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
      gamma: Discount factor.
    """
    super().__init__(grid_world, gamma)
    # Start from an all-zeros Q-table of shape (num_states, num_actions).
    self.Q = np.zeros((self.num_states, self.num_actions))
    # Pick an initial state randomly.
    self.current_state = np.random.choice(self.num_states)

  def step(self, action: int) -> Tuple[int, float]:
    """Sample one transition of the MDP from self.current_state.

    This method does not update self.current_state; the caller is expected to
    store the returned next state itself.

    Args:
      action: Action to take.

    Returns:
      Next state and reward received.
    """
    # Draw the next state from the transition distribution P[s, a, :].
    new_state = np.random.choice(self.num_states,
                                 p=self.P[self.current_state, action, :])
    # The reward is looked up by the state the action was taken from.
    return (new_state, self.R[self.current_state, action])

  def pickAction(self) -> int:
    """Pick the greedy action for the current state.

    Returns:
      Index of the action with the largest Q-value estimate in
      self.current_state (the first one if several are tied).
    """
    return np.argmax(self.Q[self.current_state, :])

  def maybeReset(self):
    """If current_state is goal, reset to a random state.

    The new state is drawn uniformly from all states.
    """
    if self.current_state == self.goal_state:
      self.current_state = np.random.choice(self.num_states)

  def learnQ(self, alpha: float = 0.1, max_steps: int = 10_000):
    """Learn the Q-function by interacting with the environment.

    The Q-table is reset to zeros before learning starts. The body below is an
    exercise stub: it raises NotImplementedError until the placeholders are
    filled in.

    Args:
      alpha: Learning rate.
      max_steps: Maximum number of steps to take.

    Raises:
      NotImplementedError: Always, until the exercise is completed.
    """
    self.Q = np.zeros((self.num_states, self.num_actions))
    num_steps = 0
    #################################################
    # Hint: Use the step(), pickAction(), and maybeReset() functions above.
    # Note: The way you initialize the Q-values is crucial here. Try first with
    # an all-zeros initialization (as is currently coded above). If it doesn't
    # work, try a different initialization.
    # Hint: The maximum possible value (given the rewards are in [0, 1]) is
    #       1 / (1 - gamma).
    # Remove the `raise` below once the `...` placeholders are filled in.
    raise NotImplementedError("Write `learnQ` function!")
    #################################################
    while num_steps < max_steps:
      a = ...
      new_state, r = ...
      td = ...
      self.Q[self.current_state, a] += ...
      self.current_state = ...
      self.maybeReset()
      num_steps += 1
    self.V = ...


  def plan(self):
    """Extract the greedy policy and the state values from the Q-values.

    Sets `self.pi` to the argmax over actions of each row of `self.Q` and
    `self.V` to the corresponding maximum.
    """
    self.pi = np.argmax(self.Q, axis=-1)
    self.V = np.max(self.Q, axis=-1)

解答を見る

# Build a greedy Q-learner for the grid world `gwb` and show the base grid.
base_q_learner = QLearner(gwb)
print('GridWorld:')
base_q_learner.draw()
# Compute Q, then extract policy from it.
base_q_learner.learnQ()
base_q_learner.plan()
# Show the resulting policy first, then the value of each state.
print('Optimal policy:')
base_q_learner.draw('policy')
base_q_learner.draw('values')
# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_Q_learning_Exercise")

セクション6: $\epsilon$-グリーディー探索

エージェントは現在の推定値と方策を活用すべきか、それともより良い方策が存在するかもしれないので環境を探索すべきか?この探索と活用のジレンマはRLの中心的な問題です。

このセクションでは、このトレードオフのための最もシンプルで効果的な方法の一つ、いわゆる$\epsilon$-グリーディー探索を紹介します。

# @title Video 8: Epsilon-greedy exploration
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  """IFrame that embeds a video hosted on Bilibili or OSF.

  The video id is kept on the instance as `self.id`.
  """

  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    """Builds the embed URL for the given host and initialises the IFrame.

    Args:
      id: Video identifier on the hosting site.
      source: Hosting site, either 'Bilibili' or 'Osf'.
      page: Page number placed in the Bilibili player URL (unused for 'Osf').
      width: Frame width.
      height: Frame height.
      **kwargs: Forwarded unchanged to the IFrame constructor.

    Raises:
      ValueError: If `source` is not one of the supported hosts.
    """
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    else:
      # Without this branch `src` would be unbound and the call below would
      # fail with an UnboundLocalError that hides the real cause.
      raise ValueError(
          f"Unsupported video source {source!r}; expected 'Bilibili' or 'Osf'.")
    super().__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  """Creates one output widget per video and displays its player inside it.

  Args:
    video_ids: Sequence of (source, id) entries. 'Youtube' entries use
      YouTubeVideo; every other source goes through PlayVideo.
    W: Player width.
    H: Player height.
    fs: Forwarded to the player as its `fs` argument.

  Returns:
    List of `widgets.Output` objects, one per entry of `video_ids`.
  """
  outputs = []
  for entry in video_ids:
    panel = widgets.Output()
    # Everything printed or displayed inside the context lands in `panel`.
    with panel:
      source, identifier = entry[0], entry[1]
      if source == 'Youtube':
        video = YouTubeVideo(id=identifier, width=W, height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=identifier, source=source, width=W, height=H,
                          fs=fs, autoplay=False)
        if source == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif source == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    outputs.append(panel)
  return outputs


# Each entry is a (source, id) pair; the source name doubles as the tab title.
video_ids = [('Youtube', 'NiF7_RK_4M4'), ('Bilibili', 'BV1Jh4y1u78c')]
tab_contents = display_videos(video_ids, W=854, H=480)
# Put every player in its own tab of a single Tab widget.
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Epsilon_greedy_exploration_Video")

コーディング演習7: $\epsilon$-グリーディー探索を実装しよう

$\epsilon$-グリーディー探索を用いた$Q$学習クラスを作成してください。

class QLearnerExplorer(QLearner):
  """Q-learner whose action selection is epsilon-greedy (exercise scaffold).

  Only `pickAction` differs from QLearner; it is intentionally left
  incomplete for the exercise.
  """

  def __init__(self, grid_world: GridWorldBase, gamma: float = 0.99,
               epsilon: float = 0.1):
    """Constructs an MDP from a GridWorldBase object.

    States should be numbered from left-to-right and from top-to-bottom.

    Args:
      grid_world: GridWorld specification.
      gamma: Discount factor.
      epsilon: Exploration rate.
    """
    super().__init__(grid_world, gamma)
    # Stored for use by pickAction.
    self.epsilon = epsilon

  def pickAction(self) -> int:
    """Pick the next action from the current state.

    Intended behaviour, per the exercise notes below: with probability
    epsilon the action is chosen at random, otherwise it is chosen from the
    Q-value estimates. The body is an exercise stub.

    Raises:
      NotImplementedError: Always, until the exercise is completed.
    """
    #################################################
    # With probability epsilon will pick the next action randomly, otherwise will
    # pick based on the Q-value estimates.
    # Hint: It should only be a few lines of code!
    # Remove the `raise` below once the `...` placeholders are filled in.
    raise NotImplementedError("Write the `pickAction` function!")
    #################################################
    if ... < ...:
      return ...
    return ...

解答を見る

# Build an epsilon-greedy Q-learner for `gwb` (default epsilon of 0.1) and
# show the base grid.
explorer = QLearnerExplorer(gwb)
print('GridWorld:')
explorer.draw()
# Compute Q, then extract policy from it.
explorer.learnQ()
explorer.plan()
# Show the resulting policy first, then the value of each state.
print('Optimal policy:')
explorer.draw('policy')
explorer.draw('values')
# @title Submit your feedback
content_review(f"{feedback_prefix}_Implement_epsilon_greedy_exploration_Exercise")

まとめ

強化学習は、エージェントが試行錯誤を通じて複雑な環境で意思決定を学習できるようにするため、人工知能において重要です。エージェントの学習方法やその学習を促進するさまざまなアルゴリズムを理解することは、効果的なRLシステムを開発する上で不可欠です。