Open In Colab   Open in Kaggle

チュートリアル 3: オートエンコーダーの応用

ボーナスデイ: オートエンコーダー

Neuromatch Academyによる

コンテンツ作成者: Marco Brigham と CCNSS チーム(2014-2018)

コンテンツレビュアー: Itzel Olivos, Karen Schroeder, Karolina Stosio, Kshitij Dwivedi, Spiros Chavlis, Michael Waskom

制作編集: Spiros Chavlis


チュートリアルの目的

オートエンコーダーの応用

豊かな内部表現を持つオートエンコーダーはMNISTの認知課題でどのように機能するか?

オートエンコーダーは見たことのない数字クラスをどのように認識するか?

ANNの画像符号化は人間の視覚とどのように異なるか?

これらの質問に答えるためのツールと技術を備えており、研究で遭遇するかもしれない多くの他の問題にも対応できることを願っています!


MNIST 認知課題$


このチュートリアルでは以下を行います:


セットアップ

# @title Install and import feedback gadget


from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_cn",
            "user_key": "y1x3mpx5",
        },
    ).render()


feedback_prefix = "Bonus_Autoencoders_T3"
# Imports
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy import ndimage

import torch
from torch import nn, optim

from sklearn.datasets import fetch_openml
# @title Figure settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle")
# @title Helper functions


def downloadMNIST():
  """
  Download MNIST dataset and transform it to torch.Tensor

  Args:
    None

  Returns:
    x_train : training images (torch.Tensor) (60000, 28, 28)
    x_test  : test images (torch.Tensor) (10000, 28, 28)
    y_train : training labels (torch.Tensor) (60000, )
    y_train : test labels (torch.Tensor) (10000, )
  """
  X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
  # Trunk the data
  n_train = 60000
  n_test = 10000

  train_idx = np.arange(0, n_train)
  test_idx = np.arange(n_train, n_train + n_test)

  x_train, y_train = X[train_idx], y[train_idx]
  x_test, y_test = X[test_idx], y[test_idx]

  # Transform np.ndarrays to torch.Tensor
  x_train = torch.from_numpy(np.reshape(x_train,
                                        (len(x_train),
                                         28, 28)).astype(np.float32))
  x_test = torch.from_numpy(np.reshape(x_test,
                                       (len(x_test),
                                        28, 28)).astype(np.float32))

  y_train = torch.from_numpy(y_train.astype(int))
  y_test = torch.from_numpy(y_test.astype(int))

  return (x_train, y_train, x_test, y_test)


def init_weights_kaiming_uniform(layer):
  """
  Initializes weights from linear PyTorch layer
  with kaiming uniform distribution.

  Args:
    layer (torch.Module)
        Pytorch layer

  Returns:
    Nothing.
  """
  # check for linear PyTorch layer
  if isinstance(layer, nn.Linear):
    # initialize weights with kaiming uniform distribution
    nn.init.kaiming_uniform_(layer.weight.data)


def init_weights_kaiming_normal(layer):
  """
  Initializes weights from linear PyTorch layer
  with kaiming normal distribution.

  Args:
    layer (torch.Module)
        Pytorch layer

  Returns:
    Nothing.
  """
  # check for linear PyTorch layer
  if isinstance(layer, nn.Linear):
    # initialize weights with kaiming normal distribution
    nn.init.kaiming_normal_(layer.weight.data)


def get_layer_weights(layer):
  """
  Retrieves learnable parameters from PyTorch layer.

  Args:
    layer (torch.Module)
        Pytorch layer

  Returns:
    list with learnable parameters
  """
  # initialize output list
  weights = []

  # check whether layer has learnable parameters
  if layer.parameters():
    # copy numpy array representation of each set of learnable parameters
    for item in layer.parameters():
      weights.append(item.detach().numpy())

  return weights


def eval_mse(y_pred, y_true):
  """
  Evaluates mean square error (MSE) between y_pred and y_true

  Args:
    y_pred (torch.Tensor)
        prediction samples

    v (numpy array of floats)
        ground truth samples

  Returns:
    MSE(y_pred, y_true)
  """

  with torch.no_grad():
    criterion = nn.MSELoss()
    loss = criterion(y_pred, y_true)

  return float(loss)


def eval_bce(y_pred, y_true):
  """
  Evaluates binary cross-entropy (BCE) between y_pred and y_true

  Args:
    y_pred (torch.Tensor)
        prediction samples

    v (numpy array of floats)
        ground truth samples

  Returns:
    BCE(y_pred, y_true)
  """

  with torch.no_grad():
    criterion = nn.BCELoss()
    loss = criterion(y_pred, y_true)

  return float(loss)


def plot_row(images, show_n=10, image_shape=None):
  """
  Plots rows of images from list of iterables (iterables: list, numpy array
  or torch.Tensor). Also accepts single iterable.
  Randomly selects images in each list element if item count > show_n.

  Args:
    images (iterable or list of iterables)
        single iterable with images, or list of iterables

    show_n (integer)
        maximum number of images per row

    image_shape (tuple or list)
        original shape of image if vectorized form

  Returns:
    Nothing.
  """

  if not isinstance(images, (list, tuple)):
    images = [images]

  for items_idx, items in enumerate(images):

    items = np.array(items)
    if items.ndim == 1:
      items = np.expand_dims(items, axis=0)

    if len(items) > show_n:
      selected = np.random.choice(len(items), show_n, replace=False)
      items = items[selected]

    if image_shape is not None:
      items = items.reshape([-1] + list(image_shape))

    plt.figure(figsize=(len(items) * 1.5, 2))
    for image_idx, image in enumerate(items):

      plt.subplot(1, len(items), image_idx + 1)
      plt.imshow(image, cmap='gray', vmin=image.min(), vmax=image.max())
      plt.axis('off')

    plt.tight_layout()


def to_s2(u):
  """
  Projects 3D coordinates to spherical coordinates (theta, phi) surface of
  unit sphere S2.
  theta: [0, pi]
  phi: [-pi, pi]

  Args:
    u (list, numpy array or torch.Tensor of floats)
        3D coordinates

  Returns:
    Spherical coordinates (theta, phi) on surface of unit sphere S2.
  """

  x, y, z = (u[:, 0], u[:, 1], u[:, 2])
  r = np.sqrt(x**2 + y**2 + z**2)
  theta = np.arccos(z / r)
  phi = np.arctan2(x, y)

  return np.array([theta, phi]).T


def to_u3(s):
  """
  Converts from 2D coordinates on surface of unit sphere S2 to 3D coordinates
  (on surface of S2), i.e. (theta, phi) ---> (1, theta, phi).

  Args:
    s (list, numpy array or torch.Tensor of floats)
        2D coordinates on unit sphere S_2

  Returns:
    3D coordinates on surface of unit sphere S_2
  """

  theta, phi = (s[:, 0], s[:, 1])
  x = np.sin(theta) * np.sin(phi)
  y = np.sin(theta) * np.cos(phi)
  z = np.cos(theta)

  return np.array([x, y, z]).T


def xy_lim(x):
  """
  Return arguments for plt.xlim and plt.ylim calculated from minimum
  and maximum of x.

  Args:
    x (list, numpy array or torch.Tensor of floats)
        data to be plotted

  Returns:
    Nothing.
  """
  x_min = np.min(x, axis=0)
  x_max = np.max(x, axis=0)

  x_min = x_min - np.abs(x_max - x_min) * 0.05 - np.finfo(float).eps
  x_max = x_max + np.abs(x_max - x_min) * 0.05 + np.finfo(float).eps

  return [x_min[0], x_max[0]], [x_min[1], x_max[1]]


def plot_generative(x, decoder_fn, image_shape, n_row=16, s2=False):
  """
  Plots images reconstructed by decoder_fn from a 2D grid in
  latent space that is determined by minimum and maximum values in x.

  Args:
    x (list, numpy array or torch.Tensor of floats)
        2D or 3D coordinates in latent space

    decoder_fn (integer)
        function returning vectorized images from 2D latent space coordinates

    image_shape (tuple or list)
        original shape of image

    n_row (integer)
        number of rows in grid

    s2 (boolean)
        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)

  Returns:
    Nothing.
  """

  if s2:
    x = to_s2(np.array(x))

  xlim, ylim = xy_lim(np.array(x))

  dx = (xlim[1] - xlim[0]) / n_row
  grid = [np.linspace(ylim[0] + dx / 2, ylim[1] - dx / 2, n_row),
          np.linspace(xlim[0] + dx / 2, xlim[1] - dx / 2, n_row)]

  canvas = np.zeros((image_shape[0] * n_row, image_shape[1] * n_row))

  cmap = plt.get_cmap('gray')

  for j, latent_y in enumerate(grid[0][::-1]):
    for i, latent_x in enumerate(grid[1]):

      latent = np.array([[latent_x, latent_y]], dtype=np.float32)

      if s2:
        latent = to_u3(latent)

      with torch.no_grad():
        x_decoded = decoder_fn(torch.from_numpy(latent))

      x_decoded = x_decoded.reshape(image_shape)

      canvas[j * image_shape[0]: (j + 1) * image_shape[0],
             i * image_shape[1]: (i + 1) * image_shape[1]] = x_decoded

  plt.imshow(canvas, cmap=cmap, vmin=canvas.min(), vmax=canvas.max())
  plt.axis('off')


def plot_latent(x, y, show_n=500, s2=False, fontdict=None, xy_labels=None):
  """
  Plots digit class of each sample in 2D latent space coordinates.

  Args:
    x (list, numpy array or torch.Tensor of floats)
        2D coordinates in latent space

    y (list, numpy array or torch.Tensor of floats)
        digit class of each sample

    n_row (integer)
        number of samples

    s2 (boolean)
        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)

    fontdict (dictionary)
        style option for plt.text

    xy_labels (list)
        optional list with [xlabel, ylabel]

  Returns:
    Nothing.
  """

  if fontdict is None:
    fontdict = {'weight': 'bold', 'size': 12}

  if s2:
    x = to_s2(np.array(x))

  cmap = plt.get_cmap('tab10')

  if len(x) > show_n:
    selected = np.random.choice(len(x), show_n, replace=False)
    x = x[selected]
    y = y[selected]

  for my_x, my_y in zip(x, y):
    plt.text(my_x[0], my_x[1], str(int(my_y)),
             color=cmap(int(my_y) / 10.),
             fontdict=fontdict,
             horizontalalignment='center',
             verticalalignment='center',
             alpha=0.8)

  xlim, ylim = xy_lim(np.array(x))
  plt.xlim(xlim)
  plt.ylim(ylim)

  if s2:
    if xy_labels is None:
      xy_labels = [r'$\varphi$', r'$\theta$']

    plt.xticks(np.arange(0, np.pi + np.pi / 6, np.pi / 6),
               ['0', '$\pi/6$', '$\pi/3$', '$\pi/2$',
                '$2\pi/3$', '$5\pi/6$', '$\pi$'])
    plt.yticks(np.arange(-np.pi, np.pi + np.pi / 3, np.pi / 3),
               ['$-\pi$', '$-2\pi/3$', '$-\pi/3$', '0',
                '$\pi/3$', '$2\pi/3$', '$\pi$'])

  if xy_labels is None:
    xy_labels = ['$Z_1$', '$Z_2$']

  plt.xlabel(xy_labels[0])
  plt.ylabel(xy_labels[1])


def plot_latent_generative(x, y, decoder_fn, image_shape, s2=False,
                           title=None, xy_labels=None):
  """
  Two horizontal subplots generated with encoder map and decoder grid.

  Args:
    x (list, numpy array or torch.Tensor of floats)
        2D coordinates in latent space

    y (list, numpy array or torch.Tensor of floats)
        digit class of each sample

    decoder_fn (integer)
        function returning vectorized images from 2D latent space coordinates

    image_shape (tuple or list)
        original shape of image

    s2 (boolean)
        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)

    title (string)
        plot title

    xy_labels (list)
        optional list with [xlabel, ylabel]

  Returns:
    Nothing.
  """

  fig = plt.figure(figsize=(12, 6))

  if title is not None:
    fig.suptitle(title, y=1.05)

  ax = fig.add_subplot(121)
  ax.set_title('Encoder map', y=1.05)
  plot_latent(x, y, s2=s2, xy_labels=xy_labels)

  ax = fig.add_subplot(122)
  ax.set_title('Decoder grid', y=1.05)
  plot_generative(x, decoder_fn, image_shape, s2=s2)

  plt.tight_layout()
  plt.show()


def plot_latent_ab(x1, x2, y, selected_idx=None,
                   title_a='Before', title_b='After', show_n=500, s2=False):
  """
  Two horizontal subplots with encoder maps.

  Args:
    x1 (list, numpy array or torch.Tensor of floats)
        2D coordinates in latent space (left plot)

    x2 (list, numpy array or torch.Tensor of floats)
        digit class of each sample (right plot)

    y (list, numpy array or torch.Tensor of floats)
        digit class of each sample

    selected_idx (list of integers)
        indexes of elements to be plotted

    show_n (integer)
        maximum number of samples in each plot

    s2 (boolean)
        convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)

  Returns:
    Nothing.
  """

  fontdict = {'weight': 'bold', 'size': 12}

  if len(x1) > show_n:

    if selected_idx is None:
      selected_idx = np.random.choice(len(x1), show_n, replace=False)

    x1 = x1[selected_idx]
    x2 = x2[selected_idx]
    y = y[selected_idx]

  data = np.concatenate([x1, x2])

  if s2:
    xlim, ylim = xy_lim(to_s2(data))

  else:
    xlim, ylim = xy_lim(data)

  plt.figure(figsize=(12, 6))

  ax = plt.subplot(121)
  ax.set_title(title_a, y=1.05)
  plot_latent(x1, y, fontdict=fontdict, s2=s2)
  plt.xlim(xlim)
  plt.ylim(ylim)

  ax = plt.subplot(122)
  ax.set_title(title_b, y=1.05)
  plot_latent(x2, y, fontdict=fontdict, s2=s2)
  plt.xlim(xlim)
  plt.ylim(ylim)
  plt.tight_layout()


def runSGD(net, input_train, input_test, out_train=None, out_test=None,
           optimizer=None, criterion='bce', n_epochs=10, batch_size=32,
           verbose=False):
  """
  Trains autoencoder network with stochastic gradient descent with
  optimizer and loss criterion. Train samples are shuffled, and loss is
  displayed at the end of each opoch for both MSE and BCE. Plots training loss
  at each minibatch (maximum of 500 randomly selected values).

  Args:
    net (torch network)
        ANN network (nn.Module)

    input_train (torch.Tensor)
        vectorized input images from train set

    input_test (torch.Tensor)
        vectorized input images from test set

    criterion (string)
        train loss: 'bce' or 'mse'

    out_train (torch.Tensor)
        optional target images from train set

    out_test (torch.Tensor)
        optional target images from test set

    optimizer (torch optimizer)
        optional target images from train set

    criterion (string)
        train loss: 'bce' or 'mse'

    n_epochs (boolean)
        number of full iterations of training data

    batch_size (integer)
        number of element in mini-batches

    verbose (boolean)
        whether to print final loss

  Returns:
    Nothing.
  """

  if out_train is not None and out_test is not None:
    different_output = True
  else:
    different_output = False

  # Initialize loss function
  if criterion == 'mse':
    loss_fn = nn.MSELoss()
  elif criterion == 'bce':
    loss_fn = nn.BCELoss()
  else:
    print('Please specify either "mse" or "bce" for loss criterion')

  # Initialize SGD optimizer
  if optimizer is None:
    optimizer = optim.Adam(net.parameters())

  # Placeholder for loss
  track_loss = []

  print('Epoch', '\t', 'Loss train', '\t', 'Loss test')
  for i in range(n_epochs):

    shuffle_idx = np.random.permutation(len(input_train))
    batches = torch.split(input_train[shuffle_idx], batch_size)

    if different_output:
      batches_out = torch.split(out_train[shuffle_idx], batch_size)

    for batch_idx, batch in enumerate(batches):

      output_train = net(batch)

      if different_output:
        loss = loss_fn(output_train, batches_out[batch_idx])
      else:
        loss = loss_fn(output_train, batch)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Keep track of loss at each epoch
      track_loss += [float(loss)]

    loss_epoch = f'{i+1}/{n_epochs}'
    with torch.no_grad():
      output_train = net(input_train)
      if different_output:
        loss_train = loss_fn(output_train, out_train)
      else:
        loss_train = loss_fn(output_train, input_train)

      loss_epoch += f'\t {loss_train:.4f}'

      output_test = net(input_test)
      if different_output:
        loss_test = loss_fn(output_test, out_test)
      else:
        loss_test = loss_fn(output_test, input_test)

      loss_epoch += f'\t\t {loss_test:.4f}'

    print(loss_epoch)

  if verbose:
    # Print loss
    if different_output:
      loss_mse = f'\nMSE\t {eval_mse(output_train, out_train):0.4f}'
      loss_mse += f'\t\t {eval_mse(output_test, out_test):0.4f}'
    else:
      loss_mse = f'\nMSE\t {eval_mse(output_train, input_train):0.4f}'
      loss_mse += f'\t\t {eval_mse(output_test, input_test):0.4f}'
    print(loss_mse)

    if different_output:
      loss_bce = f'BCE\t {eval_bce(output_train, out_train):0.4f}'
      loss_bce += f'\t\t {eval_bce(output_test, out_test):0.4f}'
    else:
      loss_bce = f'BCE\t {eval_bce(output_train, input_train):0.4f}'
      loss_bce += f'\t\t {eval_bce(output_test, input_test):0.4f}'
    print(loss_bce)

  # Plot loss
  step = int(np.ceil(len(track_loss)/500))
  x_range = np.arange(0, len(track_loss), step)
  plt.figure()
  plt.plot(x_range, track_loss[::step], 'C0')
  plt.xlabel('Iterations')
  plt.ylabel('Loss')
  plt.xlim([0, None])
  plt.ylim([0, None])
  plt.show()


def image_occlusion(x, image_shape):
  """
  Randomly selects on quadrant of images and sets to zeros.

  Args:
    x (torch.Tensor of floats)
        vectorized images

    image_shape (tuple or list)
        original shape of image

  Returns:
    torch.Tensor.
  """

  selection = np.random.choice(4, len(x))

  my_x = np.array(x).copy()
  my_x = my_x.reshape(-1, image_shape[0], image_shape[1])

  my_x[selection == 0, :int(image_shape[0] / 2), :int(image_shape[1] / 2)] = 0
  my_x[selection == 1, int(image_shape[0] / 2):, :int(image_shape[1] / 2)] = 0
  my_x[selection == 2, :int(image_shape[0] / 2), int(image_shape[1] / 2):] = 0
  my_x[selection == 3, int(image_shape[0] / 2):, int(image_shape[1] / 2):] = 0

  my_x = my_x.reshape(x.shape)

  return torch.from_numpy(my_x)


def image_rotation(x, deg, image_shape):
  """
  Randomly rotates images by +- deg degrees.

  Args:
    x (torch.Tensor of floats)
        vectorized images

    deg (integer)
        rotation range

    image_shape (tuple or list)
        original shape of image

  Returns:
    torch.Tensor.
  """

  my_x = np.array(x).copy()
  my_x = my_x.reshape(-1, image_shape[0], image_shape[1])

  for idx, item in enumerate(my_x):
    my_deg = deg * 2 * np.random.random() - deg
    my_x[idx] = ndimage.rotate(my_x[idx], my_deg,
                               reshape=False, prefilter=False)

  my_x = my_x.reshape(x.shape)

  return torch.from_numpy(my_x)


class AutoencoderClass(nn.Module):
  """
  Deep autoencoder network object (nn.Module) with optional L2 normalization
  of activations in bottleneck layer.

  Args:
    input_size (integer)
        size of input samples

    s2 (boolean)
        whether to L2 normalize activatinos in bottleneck layer

  Returns:
    Autoencoder object inherited from nn.Module class.
  """

  def __init__(self, input_size=784, s2=False):

    super().__init__()

    self.input_size = input_size
    self.s2 = s2

    if s2:
      self.encoding_size = 3

    else:
      self.encoding_size = 2

    self.enc1 = nn.Linear(self.input_size, int(self.input_size / 2))
    self.enc1_f = nn.PReLU()
    self.enc2 = nn.Linear(int(self.input_size / 2), self.encoding_size * 32)
    self.enc2_f = nn.PReLU()
    self.enc3 = nn.Linear(self.encoding_size * 32, self.encoding_size)
    self.enc3_f = nn.PReLU()
    self.dec1 = nn.Linear(self.encoding_size, self.encoding_size * 32)
    self.dec1_f = nn.PReLU()
    self.dec2 = nn.Linear(self.encoding_size * 32, int(self.input_size / 2))
    self.dec2_f = nn.PReLU()
    self.dec3 = nn.Linear(int(self.input_size / 2), self.input_size)
    self.dec3_f = nn.Sigmoid()

  def encoder(self, x):
    """
    Encoder component.
    """
    x = self.enc1_f(self.enc1(x))
    x = self.enc2_f(self.enc2(x))
    x = self.enc3_f(self.enc3(x))

    if self.s2:
        x = nn.functional.normalize(x, p=2, dim=1)

    return x

  def decoder(self, x):
    """
    Decoder component.
    """
    x = self.dec1_f(self.dec1(x))
    x = self.dec2_f(self.dec2(x))
    x = self.dec3_f(self.dec3(x))

    return x

  def forward(self, x):
    """
    Forward pass.
    """
    x = self.encoder(x)
    x = self.decoder(x)

    return x


def save_checkpoint(net, optimizer, filename):
  """
  Saves a PyTorch checkpoint.

  Args:
    net (torch network)
        ANN network (nn.Module)

    optimizer (torch optimizer)
        optimizer for SGD

    filename (string)
        filename (without extension)

  Returns:
    Nothing.
  """

  torch.save({'model_state_dict': net.state_dict(),
              'optimizer_state_dict': optimizer.state_dict()},
             filename+'.pt')


def load_checkpoint(url, filename):
  """
  Loads a PyTorch checkpoint from URL is local file not present.

  Args:
    url (string)
        URL location of PyTorch checkpoint

    filename (string)
        filename (without extension)

  Returns:
    PyTorch checkpoint of saved model.
  """

  if not os.path.isfile(filename+'.pt'):
    os.system(f"wget {url}.pt")

  return torch.load(filename+'.pt')


def reset_checkpoint(net, optimizer, checkpoint):
  """
  Resets PyTorch model to checkpoint.

  Args:
    net (torch network)
        ANN network (nn.Module)

    optimizer (torch optimizer)
        optimizer for SGD

    checkpoint (torch checkpoint)
        checkpoint of saved model

  Returns:
    Nothing.
  """

  net.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

セクション 0: はじめに

# @title Video 1: Applications
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', '_bzW_jkH6l0'), ('Bilibili', 'BV12v411q7nS')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Applications_Video")

セクション 1: MNISTデータセットのダウンロードと準備

ヘルパー関数 downloadMNIST を使ってデータセットをダウンロードし、torch.Tensor に変換して訓練セットとテストセットをそれぞれ (x_train, y_train) と (x_test, y_test) に割り当てます。

変数 input_size は、訓練用とテスト用の画像 input_traininput_testベクトル化されたバージョンの長さを格納します。

指示:

# Download MNIST
x_train, y_train, x_test, y_test = downloadMNIST()

x_train = x_train / 255
x_test = x_test / 255

image_shape = x_train.shape[1:]

input_size = np.prod(image_shape)

input_train = x_train.reshape([-1, input_size])
input_test = x_test.reshape([-1, input_size])

test_selected_idx = np.random.choice(len(x_test), 10, replace=False)
train_selected_idx = np.random.choice(len(x_train), 10, replace=False)

test_subset_idx = np.random.choice(len(x_test), 500, replace=False)

print(f'shape image \t\t {image_shape}')
print(f'shape input_train \t {input_train.shape}')
print(f'shape input_test \t {input_test.shape}')

セクション 2: 事前学習済みモデルのダウンロード

クラス AutoencoderClass は前回のチュートリアルで紹介したオートエンコーダーのアーキテクチャを実装しています。このクラスの設計はチュートリアル W3D4 のオブジェクト指向プログラミング(OOP)スタイルに従っています。ブールパラメータ s2=True を設定すると、S2S_2 球面への射影を持つモデルが指定されます。

両モデルを n_epochs=25 で訓練し、長時間の初期訓練を避けるために重みを保存しました。これらが参照モデルの状態となります。

実験はすべて同一の初期条件から開始し、各演習の開始時にオートエンコーダーを参照状態にリセットします。

PyTorchでモデルを保存・読み込みする仕組みは以下の通りです:

model = nn.Sequential(...)

または

model = AutoencoderClass()

そして

torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()},
           filename_path)

checkpoint = torch.load(filename_path)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

詳細はPyTorchの説明$を参照し、より複雑なモデルでは model.eval()model.train() の使い分けも確認してください。

関数 save_checkpointload_checkpointreset_checkpoint を提供しており、上記の手順を実装しGitHubリポジトリから事前学習済み重みをダウンロードします。

GitHubからのダウンロードに失敗した場合は、以下の3番目のセルのコメントを外して n_epochs=10 でモデルを訓練し、ローカルに保存してください。

指示:

root = 'https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders'
filename = 'ae_6h_prelu_bce_adam_25e_32b'
url = os.path.join(root, filename)
s2 = True

if s2:
  filename += '_s2'
  url += '_s2'
model = AutoencoderClass(s2=s2)
optimizer = optim.Adam(model.parameters())

encoder = model.encoder
decoder = model.decoder

checkpoint = load_checkpoint(url, filename)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Please uncomment and execute this cell if download if pre-trained weights fail

# model = AutoencoderClass(s2=s2)
# encoder = model.encoder
# decoder = model.decoder
# n_epochs = 10
# batch_size = 128
# runSGD(model, input_train, input_test,
#        n_epochs=n_epochs, batch_size=batch_size)
# save_checkpoint(model, optimizer, filename)
# checkpoint = load_checkpoint(url, filename)
with torch.no_grad():
  output_test = model(input_test)
  latent_test = encoder(input_test)

plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],
         image_shape=image_shape)

plot_latent_generative(latent_test, y_test, decoder,
                       image_shape=image_shape, s2=s2)

セクション 3: オートエンコーダーの応用

応用 1 - 画像ノイズ

画像に加えられたノイズを除去することは、次元削減技術でよく示されます。次元削減の日のチュートリアルではPCAでこの能力を示しました。

まず、ノイズなしの画像で訓練されたオートエンコーダーは、ノイズのある画像を入力として受け取るとノイズのない画像を出力することを観察します。ただし、再構成された画像は元の画像(ノイズなし)とは異なります。なぜなら、加えられたノイズが潜在空間の異なる座標にマッピングされるためです。

ノイズなしとノイズありのバージョンを潜在空間の類似領域にマッピングする能力は、ロバスト性またはノイズに対する不変性として知られています。オートエンコーダーにこの機能を組み込むにはどうすればよいでしょうか?

解決策は、ノイズなしとノイズありのバージョンをノイズなしバージョンにマッピングするようにオートエンコーダーを訓練することです。より速い方法は、ノイズあり画像で数エポックだけ再訓練することです。これらの短時間の訓練セッションは、ノイズあり画像を類似の潜在空間座標からノイズなしバージョンにマッピングするよう重みを微調整します。

まず、オートエンコーダーを参照状態にリセットしましょう。

指示:

reset_checkpoint(model, optimizer, checkpoint)

with torch.no_grad():
  latent_test_ref = encoder(input_test)

微調整前の再構成

ノイズなし画像で訓練されたオートエンコーダーが、ノイズのある入力からノイズなしの画像を出力することを確認しましょう。3行のプロットで可視化します:

ノイズ課題$

下段はノイズを加える前の再構成に問題があるサンプルを特定するのに役立ちます。この行は元画像ではなくこれらのサンプルの基準となる再構成品質を示しています。(なぜでしょう?)

指示:

noise_factor = 0.4

input_train_noisy = (input_train
                     + noise_factor * np.random.normal(size=input_train.shape))
input_train_noisy = np.clip(input_train_noisy, input_train.min(),
                            input_train.max(), dtype=np.float32)

input_test_noisy = (input_test
                    + noise_factor * np.random.normal(size=input_test.shape))
input_test_noisy = np.clip(input_test_noisy, input_test.min(),
                           input_test.max(), dtype=np.float32)
with torch.no_grad():
  output_test_noisy = model(input_test_noisy)
  latent_test_noisy = encoder(input_test_noisy)
  output_test = model(input_test)

plot_row([input_test_noisy[test_selected_idx],
          output_test_noisy[test_selected_idx],
          output_test[test_selected_idx]], image_shape=image_shape)

微調整前の潜在空間

入力にノイズを加えることが潜在空間の座標にどのように影響するかを調べ、再構成誤差の原因を探ります。デコーダーは座標の大きな変化を異なる数字として解釈します。

関数 plot_latent_ab は2つの条件間で同じサンプルセットの潜在空間座標を比較します。ここでは、前のセルの10サンプルのノイズなしとノイズありの座標を表示します:

指示:

plot_latent_ab(latent_test, latent_test_noisy, y_test, test_selected_idx,
               title_a='Before noise', title_b='After noise', s2=s2)

ノイズあり画像でオートエンコーダーを微調整

ノイズあり画像を入力、元のノイズなし画像を出力としてオートエンコーダーを再訓練し、前のプロットを再生成します。

ノイズありとノイズなしの画像が類似の潜在空間位置にマッチすることがわかります。ネットワークはノイズに対してよりロバストな潜在空間表現で入力をデノイズします。

指示:

n_epochs = 3
batch_size = 32

model.train()

runSGD(model, input_train_noisy, input_test_noisy,
       out_train=input_train, out_test=input_test,
       n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
  output_test_noisy = model(input_test_noisy)
  latent_test_noisy = encoder(input_test_noisy)
  output_test = model(input_test)

plot_row([input_test_noisy[test_selected_idx],
          output_test_noisy[test_selected_idx],
          output_test[test_selected_idx]], image_shape=image_shape)

plot_latent_ab(latent_test, latent_test_noisy, y_test, test_selected_idx,
               title_a='Before fine-tuning',
               title_b='After fine-tuning', s2=s2)

潜在空間の全体的なシフト

新しい潜在空間表現はノイズに対してよりロバストで、データセットの内部表現が改善されている可能性があります。ノイズあり画像で微調整する前後のクリーン画像の潜在空間を調べて確認します。

ノイズあり画像でネットワークを微調整すると、データセットにドメインシフト(分布の変化)が生じます。元はノイズなし画像で構成されていたためです。再訓練のタスクや変化の程度(エポック数、オプティマイザの特性など)によっては、新しい潜在空間表現が副作用として元のデータに対して適応度が低下することがあります。ドメインシフトにどう対処し、ノイズあり・なし両方の画像を改善できるでしょうか?

指示:

with torch.no_grad():
  latent_test = encoder(input_test)

plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,
               title_a='Before fine-tuning',
               title_b='After fine-tuning', s2=s2)

応用 2 - 画像の部分遮蔽

次に画像の部分遮蔽の影響を調べます。前の演習から、訓練セットに遮蔽画像が含まれていないため、オートエンコーダーは完全な画像を再構成すると予想されます(そうですよね?)。

3行のプロットで可視化します:

遮蔽課題$

同様に、潜在空間における部分画像の表現と微調整後の変化を調べ、この問題の原因を探ります。

指示:

reset_checkpoint(model, optimizer, checkpoint)

with torch.no_grad():
  latent_test_ref = encoder(input_test)

微調整前

指示:

input_train_mask = image_occlusion(input_train, image_shape=image_shape)
input_test_mask = image_occlusion(input_test, image_shape=image_shape)
with torch.no_grad():
  output_test_mask = model(input_test_mask)
  latent_test_mask = encoder(input_test_mask)
  output_test = model(input_test)

plot_row([input_test_mask[test_selected_idx],
          output_test_mask[test_selected_idx],
          output_test[test_selected_idx]], image_shape=image_shape)

plot_latent_ab(latent_test, latent_test_mask, y_test, test_selected_idx,
               title_a='Before occlusion', title_b='After occlusion', s2=s2)

微調整後

n_epochs = 3
batch_size = 32

model.train()

runSGD(model, input_train_mask, input_test_mask,
       out_train=input_train, out_test=input_test,
       n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
  output_test_mask = model(input_test_mask)
  latent_test_mask = encoder(input_test_mask)
  output_test = model(input_test)

plot_row([input_test_mask[test_selected_idx],
          output_test_mask[test_selected_idx],
          output_test[test_selected_idx]], image_shape=image_shape)

plot_latent_ab(latent_test, latent_test_mask, y_test, test_selected_idx,
               title_a='Before fine-tuning',
               title_b='After fine-tuning', s2=s2)
with torch.no_grad():
  latent_test = encoder(input_test)

plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,
               title_a='Before fine-tuning',
               title_b='After fine-tuning', s2=s2)

応用 3 - 画像の回転

最後に画像の回転が潜在空間座標に与える影響を見ます。この課題は画像再構成の完全な書き換えを必要とするため、より難しいかもしれません。

3行のプロットで可視化します:

回転課題$

回転画像の潜在空間表現と微調整後の変化を調べ、この問題の原因を探ります。

指示:

reset_checkpoint(model, optimizer, checkpoint)

with torch.no_grad():
  latent_test_ref = encoder(input_test)

微調整前

指示:

input_train_rotation = image_rotation(input_train, 90, image_shape=image_shape)
input_test_rotation = image_rotation(input_test, 90, image_shape=image_shape)
with torch.no_grad():
  output_test_rotation = model(input_test_rotation)
  latent_test_rotation = encoder(input_test_rotation)
  output_test = model(input_test)

plot_row([input_test_rotation[test_selected_idx],
          output_test_rotation[test_selected_idx],
          output_test[test_selected_idx]], image_shape=image_shape)

plot_latent_ab(latent_test, latent_test_rotation, y_test, test_selected_idx,
               title_a='Before rotation', title_b='After rotation', s2=s2)

微調整後

指示:

n_epochs = 5
batch_size = 32

model.train()

runSGD(model, input_train_rotation, input_test_rotation,
       out_train=input_train, out_test=input_test,
       n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
  output_test_rotation = model(input_test_rotation)
  latent_test_rotation = encoder(input_test_rotation)
  output_test = model(input_test)

plot_row([input_test_rotation[test_selected_idx],
          output_test_rotation[test_selected_idx],
          output_test[test_selected_idx]], image_shape=image_shape)

plot_latent_ab(latent_test, latent_test_rotation, y_test, test_selected_idx,
               title_a='Before fine-tuning',
               title_b='After fine-tuning', s2=s2)
with torch.no_grad():
  latent_test = encoder(input_test)

plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,
               title_a='Before fine-tuning',
               title_b='After fine-tuning', s2=s2)

応用 4 - もし数字「6」を一度も見たことがなかったらどのように見えるか?

そんな不可能な課題で頭を悩ませる前に、オートエンコーダーにやらせてみましょう!

数字クラス 6 を除いてオートエンコーダーを最初から訓練し、数字 6 の再構成を可視化します。

指示:

model = AutoencoderClass(s2=s2)
optimizer = optim.Adam(model.parameters())

encoder = model.encoder
decoder = model.decoder
missing = 6

my_input_train = input_train[y_train != missing]
my_input_test = input_test[y_test != missing]
my_y_test = y_test[y_test != missing]
n_epochs = 3
batch_size = 32

runSGD(model, my_input_train, my_input_test,
       n_epochs=n_epochs, batch_size=batch_size)

with torch.no_grad():
  output_test = model(input_test)
  my_latent_test = encoder(my_input_test)
plot_row([input_test[y_test == 6], output_test[y_test == 6]],
         image_shape=image_shape)

plot_latent_generative(my_latent_test, my_y_test, decoder,
                       image_shape=image_shape, s2=s2)

コーディング演習 1: 最も支配的な数字クラスの除去

数字クラス 01 は、他の数字クラスに比べてデコーダーグリッドの大きな領域を占めているため支配的です。

最も支配的な2つの数字クラスを除去すると潜在空間はどう変わるでしょうか?残りのクラスに均等に再分布するでしょうか、それとも別の2つの支配的クラスを選ぶでしょうか?

指示:

model = AutoencoderClass(s2=s2)
optimizer = optim.Adam(model.parameters())

encoder = model.encoder
decoder = model.decoder
missing_a = 1
missing_b = 0
#####################################################################
# Fill in missing code (...),
# then remove or comment the line below to test your function
raise NotImplementedError("Complete the code elements below!")
#####################################################################
# input train data
my_input_train = ...
# input test data
my_input_test = ...
# model
my_y_test = ...

print(my_input_train.shape)
print(my_input_test.shape)
print(my_y_test.shape)

出力例

torch.Size([47335, 784])
torch.Size([7885, 784])
torch.Size([7885])

解答を見る$

n_epochs = 3
batch_size = 32

runSGD(model, my_input_train, my_input_test,
       n_epochs=n_epochs, batch_size=batch_size)

with torch.no_grad():
  output_test = model(input_test)
  my_latent_test = encoder(my_input_test)
plot_row([input_test[y_test == missing_a], output_test[y_test == missing_a]],
         image_shape=image_shape)

plot_row([input_test[y_test == missing_b], output_test[y_test == missing_b]],
         image_shape=image_shape)

plot_latent_generative(my_latent_test, my_y_test, decoder,
                       image_shape=image_shape, s2=s2)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Removing_the_most_dominant_class_Exercise")

セクション 4: ANN?同じだけど違う!

「Same same but different(同じようで違う)」は、アジアの一部で似ているはずの対象の違いを表現する言葉です。この演習では、全結合ANNが人間の視覚と比較して視覚情報を処理する根本的な違いを調べます。

前の演習ではANNオートエンコーダーが認知課題を比較的容易にこなすことを示しました。しかし、画像のベクトル化にすでにANN処理の重要な側面が符号化されています。このネットワークアーキテクチャはピクセルの相対位置を完全に無視します。これを示すために、ピクセル位置をシャッフルしても学習が同様に進むことを示します。

まず、shuffle_image_idx に格納された可逆なシャッフルマップを取得し、画像のピクセルをランダムにシャッフルします。


mnist_pixel_shuffle$


シャッフルされていない画像セット input_shuffle は以下で復元されます:

input_shuffle[:, shuffle_rev_image_idx]]

まず、可逆シャッフルマップを設定し、シャッフルされたピクセルとシャッフル解除されたピクセルの画像をいくつか可視化し、その後ノイズありバージョンも表示します。

指示:

# create forward and reverse indexes for pixel shuffling
shuffle_image_idx = np.arange(input_size)
shuffle_rev_image_idx = np.empty_like(shuffle_image_idx)

# shuffle pixel location
np.random.shuffle(shuffle_image_idx)

# store reverse locations
for pos_idx, pos in enumerate(shuffle_image_idx):
  shuffle_rev_image_idx[pos] = pos_idx

# shuffle train and test sets
input_train_shuffle = input_train[:, shuffle_image_idx]
input_test_shuffle = input_test[:, shuffle_image_idx]

input_train_shuffle_noisy = input_train_noisy[:, shuffle_image_idx]
input_test_shuffle_noisy = input_test_noisy[:, shuffle_image_idx]

# show samples with shuffled pixels
plot_row([input_test_shuffle,
          input_test_shuffle[:, shuffle_rev_image_idx]],
         image_shape=image_shape)
# show noisy samples with shuffled pixels
plot_row([input_train_shuffle_noisy[test_selected_idx],
          input_train_shuffle_noisy[:, shuffle_rev_image_idx][test_selected_idx]],
         image_shape=image_shape)

ノイズ除去タスクでシャッフルされたピクセルを使ってネットワークを初期化し訓練します。

指示:

model = AutoencoderClass(s2=s2)

encoder = model.encoder
decoder = model.decoder

n_epochs = 3
batch_size = 32

# train the model to denoise shuffled images
runSGD(model, input_train_shuffle_noisy, input_test_shuffle_noisy,
       out_train=input_train_shuffle, out_test=input_test_shuffle,
       n_epochs=n_epochs, batch_size=batch_size)

最後に、訓練済みモデルで再構成と潜在空間表現を可視化します。

再構成は3行で可視化します:

mnist_pixel_shuffle denoised$

エンコーダーマップは以前と同様の構造を示します。類似の内部表現を共有することは、ネットワークがピクセルの相対位置を無視していることを確認します。デコーダーグリッドはシャッフルされた画像を生成するため以前とは異なります。

指示:

with torch.no_grad():
  latent_test_shuffle_noisy = encoder(input_test_shuffle_noisy)
  output_test_shuffle_noisy = model(input_test_shuffle_noisy)

plot_row([input_test_shuffle_noisy[test_selected_idx],
          output_test_shuffle_noisy[test_selected_idx],
          output_test_shuffle_noisy[:, shuffle_rev_image_idx][test_selected_idx]],
         image_shape=image_shape)

plot_latent_generative(latent_test_shuffle_noisy, y_test, decoder,
                       image_shape=image_shape, s2=s2)

まとめ

おめでとうございます!NMA 2020の最後のチュートリアルを修了しました!

これらのチュートリアルを楽しんでいただき、オートエンコーダーがデータの豊かで非線形な低次元構造をモデル化するのに有用であることを学んでいただけたことを願っています。認知の特定の側面をモデル化したり、生物学的に妥当なアーキテクチャ(スパイキングニューロンのオートエンコーダーなど)に拡張したりする際に役立つかもしれません。

これらのチュートリアルからの主なメッセージは以下の通りです:

圧縮・復元やノイズ除去などの実践的学習タスクで訓練されたオートエンコーダーは、構造化された画像や他の認知的に関連するデータに埋め込まれた豊かな低次元構造を明らかにできる。

訓練時に見たデータドメインは「認知バイアス」を刻印する — 期待するものしか見えず、それは以前に見たものに似ているだけである。

このバイアスは心理学者ダニエル・カーネマンが提唱した見えるものがすべて$の概念に関連しています。

神経科学へのオートエンコーダーの追加応用については、アウトロビデオのスパイクソーティング応用を参照し、またこちらで実際のニューロンネットワークの入出力関係をオートエンコーダーで再現する方法をご覧ください。

# @title Video 2: Wrap-up
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display


class PlayVideo(IFrame):
  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
    self.id = id
    if source == 'Bilibili':
      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
    elif source == 'Osf':
      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
    super(PlayVideo, self).__init__(src, width, height, **kwargs)


def display_videos(video_ids, W=400, H=300, fs=1):
  tab_contents = []
  for i, video_id in enumerate(video_ids):
    out = widgets.Output()
    with out:
      if video_ids[i][0] == 'Youtube':
        video = YouTubeVideo(id=video_ids[i][1], width=W,
                             height=H, fs=fs, rel=0)
        print(f'Video available at https://youtube.com/watch?v={video.id}')
      else:
        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
                          height=H, fs=fs, autoplay=False)
        if video_ids[i][0] == 'Bilibili':
          print(f'Video available at https://www.bilibili.com/video/{video.id}')
        elif video_ids[i][0] == 'Osf':
          print(f'Video available at https://osf.io/{video.id}')
      display(video)
    tab_contents.append(out)
  return tab_contents


video_ids = [('Youtube', 'ziiZK9P6AXQ'), ('Bilibili', 'BV1ph411Z7uh')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
  tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_WrapUp_Video")