チュートリアル 3: オートエンコーダーの応用
ボーナスデイ: オートエンコーダー
Neuromatch Academyによる
コンテンツ作成者: Marco Brigham と CCNSS チーム(2014-2018)
コンテンツレビュアー: Itzel Olivos, Karen Schroeder, Karolina Stosio, Kshitij Dwivedi, Spiros Chavlis, Michael Waskom
制作編集: Spiros Chavlis
チュートリアルの目的
オートエンコーダーの応用
豊かな内部表現を持つオートエンコーダーはMNISTの認知課題でどのように機能するか?
オートエンコーダーは見たことのない数字クラスをどのように認識するか?
ANNの画像符号化は人間の視覚とどのように異なるか?
これらの質問に答えるためのツールと技術を備えており、研究で遭遇するかもしれない多くの他の問題にも対応できることを願っています!
$
このチュートリアルでは以下を行います:
- 変換されたデータ(ノイズ追加、部分遮蔽、回転)をオートエンコーダーがどのように認識し、短時間の再訓練でどのように変化するかを分析する
- オートエンコーダーを使って見たことのない数字クラスを可視化する
- 全結合ANNオートエンコーダーの視覚符号化を理解する
セットアップ
# @title Install and import feedback gadget
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_cn",
"user_key": "y1x3mpx5",
},
).render()
feedback_prefix = "Bonus_Autoencoders_T3"
# Imports
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy import ndimage
import torch
from torch import nn, optim
from sklearn.datasets import fetch_openml
# @title Figure settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle")
# @title Helper functions
def downloadMNIST():
"""
Download MNIST dataset and transform it to torch.Tensor
Args:
None
Returns:
x_train : training images (torch.Tensor) (60000, 28, 28)
x_test : test images (torch.Tensor) (10000, 28, 28)
y_train : training labels (torch.Tensor) (60000, )
y_train : test labels (torch.Tensor) (10000, )
"""
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
# Trunk the data
n_train = 60000
n_test = 10000
train_idx = np.arange(0, n_train)
test_idx = np.arange(n_train, n_train + n_test)
x_train, y_train = X[train_idx], y[train_idx]
x_test, y_test = X[test_idx], y[test_idx]
# Transform np.ndarrays to torch.Tensor
x_train = torch.from_numpy(np.reshape(x_train,
(len(x_train),
28, 28)).astype(np.float32))
x_test = torch.from_numpy(np.reshape(x_test,
(len(x_test),
28, 28)).astype(np.float32))
y_train = torch.from_numpy(y_train.astype(int))
y_test = torch.from_numpy(y_test.astype(int))
return (x_train, y_train, x_test, y_test)
def init_weights_kaiming_uniform(layer):
"""
Initializes weights from linear PyTorch layer
with kaiming uniform distribution.
Args:
layer (torch.Module)
Pytorch layer
Returns:
Nothing.
"""
# check for linear PyTorch layer
if isinstance(layer, nn.Linear):
# initialize weights with kaiming uniform distribution
nn.init.kaiming_uniform_(layer.weight.data)
def init_weights_kaiming_normal(layer):
"""
Initializes weights from linear PyTorch layer
with kaiming normal distribution.
Args:
layer (torch.Module)
Pytorch layer
Returns:
Nothing.
"""
# check for linear PyTorch layer
if isinstance(layer, nn.Linear):
# initialize weights with kaiming normal distribution
nn.init.kaiming_normal_(layer.weight.data)
def get_layer_weights(layer):
"""
Retrieves learnable parameters from PyTorch layer.
Args:
layer (torch.Module)
Pytorch layer
Returns:
list with learnable parameters
"""
# initialize output list
weights = []
# check whether layer has learnable parameters
if layer.parameters():
# copy numpy array representation of each set of learnable parameters
for item in layer.parameters():
weights.append(item.detach().numpy())
return weights
def eval_mse(y_pred, y_true):
"""
Evaluates mean square error (MSE) between y_pred and y_true
Args:
y_pred (torch.Tensor)
prediction samples
v (numpy array of floats)
ground truth samples
Returns:
MSE(y_pred, y_true)
"""
with torch.no_grad():
criterion = nn.MSELoss()
loss = criterion(y_pred, y_true)
return float(loss)
def eval_bce(y_pred, y_true):
"""
Evaluates binary cross-entropy (BCE) between y_pred and y_true
Args:
y_pred (torch.Tensor)
prediction samples
v (numpy array of floats)
ground truth samples
Returns:
BCE(y_pred, y_true)
"""
with torch.no_grad():
criterion = nn.BCELoss()
loss = criterion(y_pred, y_true)
return float(loss)
def plot_row(images, show_n=10, image_shape=None):
"""
Plots rows of images from list of iterables (iterables: list, numpy array
or torch.Tensor). Also accepts single iterable.
Randomly selects images in each list element if item count > show_n.
Args:
images (iterable or list of iterables)
single iterable with images, or list of iterables
show_n (integer)
maximum number of images per row
image_shape (tuple or list)
original shape of image if vectorized form
Returns:
Nothing.
"""
if not isinstance(images, (list, tuple)):
images = [images]
for items_idx, items in enumerate(images):
items = np.array(items)
if items.ndim == 1:
items = np.expand_dims(items, axis=0)
if len(items) > show_n:
selected = np.random.choice(len(items), show_n, replace=False)
items = items[selected]
if image_shape is not None:
items = items.reshape([-1] + list(image_shape))
plt.figure(figsize=(len(items) * 1.5, 2))
for image_idx, image in enumerate(items):
plt.subplot(1, len(items), image_idx + 1)
plt.imshow(image, cmap='gray', vmin=image.min(), vmax=image.max())
plt.axis('off')
plt.tight_layout()
def to_s2(u):
"""
Projects 3D coordinates to spherical coordinates (theta, phi) surface of
unit sphere S2.
theta: [0, pi]
phi: [-pi, pi]
Args:
u (list, numpy array or torch.Tensor of floats)
3D coordinates
Returns:
Spherical coordinates (theta, phi) on surface of unit sphere S2.
"""
x, y, z = (u[:, 0], u[:, 1], u[:, 2])
r = np.sqrt(x**2 + y**2 + z**2)
theta = np.arccos(z / r)
phi = np.arctan2(x, y)
return np.array([theta, phi]).T
def to_u3(s):
"""
Converts from 2D coordinates on surface of unit sphere S2 to 3D coordinates
(on surface of S2), i.e. (theta, phi) ---> (1, theta, phi).
Args:
s (list, numpy array or torch.Tensor of floats)
2D coordinates on unit sphere S_2
Returns:
3D coordinates on surface of unit sphere S_2
"""
theta, phi = (s[:, 0], s[:, 1])
x = np.sin(theta) * np.sin(phi)
y = np.sin(theta) * np.cos(phi)
z = np.cos(theta)
return np.array([x, y, z]).T
def xy_lim(x):
"""
Return arguments for plt.xlim and plt.ylim calculated from minimum
and maximum of x.
Args:
x (list, numpy array or torch.Tensor of floats)
data to be plotted
Returns:
Nothing.
"""
x_min = np.min(x, axis=0)
x_max = np.max(x, axis=0)
x_min = x_min - np.abs(x_max - x_min) * 0.05 - np.finfo(float).eps
x_max = x_max + np.abs(x_max - x_min) * 0.05 + np.finfo(float).eps
return [x_min[0], x_max[0]], [x_min[1], x_max[1]]
def plot_generative(x, decoder_fn, image_shape, n_row=16, s2=False):
"""
Plots images reconstructed by decoder_fn from a 2D grid in
latent space that is determined by minimum and maximum values in x.
Args:
x (list, numpy array or torch.Tensor of floats)
2D or 3D coordinates in latent space
decoder_fn (integer)
function returning vectorized images from 2D latent space coordinates
image_shape (tuple or list)
original shape of image
n_row (integer)
number of rows in grid
s2 (boolean)
convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)
Returns:
Nothing.
"""
if s2:
x = to_s2(np.array(x))
xlim, ylim = xy_lim(np.array(x))
dx = (xlim[1] - xlim[0]) / n_row
grid = [np.linspace(ylim[0] + dx / 2, ylim[1] - dx / 2, n_row),
np.linspace(xlim[0] + dx / 2, xlim[1] - dx / 2, n_row)]
canvas = np.zeros((image_shape[0] * n_row, image_shape[1] * n_row))
cmap = plt.get_cmap('gray')
for j, latent_y in enumerate(grid[0][::-1]):
for i, latent_x in enumerate(grid[1]):
latent = np.array([[latent_x, latent_y]], dtype=np.float32)
if s2:
latent = to_u3(latent)
with torch.no_grad():
x_decoded = decoder_fn(torch.from_numpy(latent))
x_decoded = x_decoded.reshape(image_shape)
canvas[j * image_shape[0]: (j + 1) * image_shape[0],
i * image_shape[1]: (i + 1) * image_shape[1]] = x_decoded
plt.imshow(canvas, cmap=cmap, vmin=canvas.min(), vmax=canvas.max())
plt.axis('off')
def plot_latent(x, y, show_n=500, s2=False, fontdict=None, xy_labels=None):
"""
Plots digit class of each sample in 2D latent space coordinates.
Args:
x (list, numpy array or torch.Tensor of floats)
2D coordinates in latent space
y (list, numpy array or torch.Tensor of floats)
digit class of each sample
n_row (integer)
number of samples
s2 (boolean)
convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)
fontdict (dictionary)
style option for plt.text
xy_labels (list)
optional list with [xlabel, ylabel]
Returns:
Nothing.
"""
if fontdict is None:
fontdict = {'weight': 'bold', 'size': 12}
if s2:
x = to_s2(np.array(x))
cmap = plt.get_cmap('tab10')
if len(x) > show_n:
selected = np.random.choice(len(x), show_n, replace=False)
x = x[selected]
y = y[selected]
for my_x, my_y in zip(x, y):
plt.text(my_x[0], my_x[1], str(int(my_y)),
color=cmap(int(my_y) / 10.),
fontdict=fontdict,
horizontalalignment='center',
verticalalignment='center',
alpha=0.8)
xlim, ylim = xy_lim(np.array(x))
plt.xlim(xlim)
plt.ylim(ylim)
if s2:
if xy_labels is None:
xy_labels = [r'$\varphi$', r'$\theta$']
plt.xticks(np.arange(0, np.pi + np.pi / 6, np.pi / 6),
['0', '$\pi/6$', '$\pi/3$', '$\pi/2$',
'$2\pi/3$', '$5\pi/6$', '$\pi$'])
plt.yticks(np.arange(-np.pi, np.pi + np.pi / 3, np.pi / 3),
['$-\pi$', '$-2\pi/3$', '$-\pi/3$', '0',
'$\pi/3$', '$2\pi/3$', '$\pi$'])
if xy_labels is None:
xy_labels = ['$Z_1$', '$Z_2$']
plt.xlabel(xy_labels[0])
plt.ylabel(xy_labels[1])
def plot_latent_generative(x, y, decoder_fn, image_shape, s2=False,
title=None, xy_labels=None):
"""
Two horizontal subplots generated with encoder map and decoder grid.
Args:
x (list, numpy array or torch.Tensor of floats)
2D coordinates in latent space
y (list, numpy array or torch.Tensor of floats)
digit class of each sample
decoder_fn (integer)
function returning vectorized images from 2D latent space coordinates
image_shape (tuple or list)
original shape of image
s2 (boolean)
convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)
title (string)
plot title
xy_labels (list)
optional list with [xlabel, ylabel]
Returns:
Nothing.
"""
fig = plt.figure(figsize=(12, 6))
if title is not None:
fig.suptitle(title, y=1.05)
ax = fig.add_subplot(121)
ax.set_title('Encoder map', y=1.05)
plot_latent(x, y, s2=s2, xy_labels=xy_labels)
ax = fig.add_subplot(122)
ax.set_title('Decoder grid', y=1.05)
plot_generative(x, decoder_fn, image_shape, s2=s2)
plt.tight_layout()
plt.show()
def plot_latent_ab(x1, x2, y, selected_idx=None,
title_a='Before', title_b='After', show_n=500, s2=False):
"""
Two horizontal subplots with encoder maps.
Args:
x1 (list, numpy array or torch.Tensor of floats)
2D coordinates in latent space (left plot)
x2 (list, numpy array or torch.Tensor of floats)
digit class of each sample (right plot)
y (list, numpy array or torch.Tensor of floats)
digit class of each sample
selected_idx (list of integers)
indexes of elements to be plotted
show_n (integer)
maximum number of samples in each plot
s2 (boolean)
convert 3D coordinates (x, y, z) to spherical coordinates (theta, phi)
Returns:
Nothing.
"""
fontdict = {'weight': 'bold', 'size': 12}
if len(x1) > show_n:
if selected_idx is None:
selected_idx = np.random.choice(len(x1), show_n, replace=False)
x1 = x1[selected_idx]
x2 = x2[selected_idx]
y = y[selected_idx]
data = np.concatenate([x1, x2])
if s2:
xlim, ylim = xy_lim(to_s2(data))
else:
xlim, ylim = xy_lim(data)
plt.figure(figsize=(12, 6))
ax = plt.subplot(121)
ax.set_title(title_a, y=1.05)
plot_latent(x1, y, fontdict=fontdict, s2=s2)
plt.xlim(xlim)
plt.ylim(ylim)
ax = plt.subplot(122)
ax.set_title(title_b, y=1.05)
plot_latent(x2, y, fontdict=fontdict, s2=s2)
plt.xlim(xlim)
plt.ylim(ylim)
plt.tight_layout()
def runSGD(net, input_train, input_test, out_train=None, out_test=None,
optimizer=None, criterion='bce', n_epochs=10, batch_size=32,
verbose=False):
"""
Trains autoencoder network with stochastic gradient descent with
optimizer and loss criterion. Train samples are shuffled, and loss is
displayed at the end of each opoch for both MSE and BCE. Plots training loss
at each minibatch (maximum of 500 randomly selected values).
Args:
net (torch network)
ANN network (nn.Module)
input_train (torch.Tensor)
vectorized input images from train set
input_test (torch.Tensor)
vectorized input images from test set
criterion (string)
train loss: 'bce' or 'mse'
out_train (torch.Tensor)
optional target images from train set
out_test (torch.Tensor)
optional target images from test set
optimizer (torch optimizer)
optional target images from train set
criterion (string)
train loss: 'bce' or 'mse'
n_epochs (boolean)
number of full iterations of training data
batch_size (integer)
number of element in mini-batches
verbose (boolean)
whether to print final loss
Returns:
Nothing.
"""
if out_train is not None and out_test is not None:
different_output = True
else:
different_output = False
# Initialize loss function
if criterion == 'mse':
loss_fn = nn.MSELoss()
elif criterion == 'bce':
loss_fn = nn.BCELoss()
else:
print('Please specify either "mse" or "bce" for loss criterion')
# Initialize SGD optimizer
if optimizer is None:
optimizer = optim.Adam(net.parameters())
# Placeholder for loss
track_loss = []
print('Epoch', '\t', 'Loss train', '\t', 'Loss test')
for i in range(n_epochs):
shuffle_idx = np.random.permutation(len(input_train))
batches = torch.split(input_train[shuffle_idx], batch_size)
if different_output:
batches_out = torch.split(out_train[shuffle_idx], batch_size)
for batch_idx, batch in enumerate(batches):
output_train = net(batch)
if different_output:
loss = loss_fn(output_train, batches_out[batch_idx])
else:
loss = loss_fn(output_train, batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Keep track of loss at each epoch
track_loss += [float(loss)]
loss_epoch = f'{i+1}/{n_epochs}'
with torch.no_grad():
output_train = net(input_train)
if different_output:
loss_train = loss_fn(output_train, out_train)
else:
loss_train = loss_fn(output_train, input_train)
loss_epoch += f'\t {loss_train:.4f}'
output_test = net(input_test)
if different_output:
loss_test = loss_fn(output_test, out_test)
else:
loss_test = loss_fn(output_test, input_test)
loss_epoch += f'\t\t {loss_test:.4f}'
print(loss_epoch)
if verbose:
# Print loss
if different_output:
loss_mse = f'\nMSE\t {eval_mse(output_train, out_train):0.4f}'
loss_mse += f'\t\t {eval_mse(output_test, out_test):0.4f}'
else:
loss_mse = f'\nMSE\t {eval_mse(output_train, input_train):0.4f}'
loss_mse += f'\t\t {eval_mse(output_test, input_test):0.4f}'
print(loss_mse)
if different_output:
loss_bce = f'BCE\t {eval_bce(output_train, out_train):0.4f}'
loss_bce += f'\t\t {eval_bce(output_test, out_test):0.4f}'
else:
loss_bce = f'BCE\t {eval_bce(output_train, input_train):0.4f}'
loss_bce += f'\t\t {eval_bce(output_test, input_test):0.4f}'
print(loss_bce)
# Plot loss
step = int(np.ceil(len(track_loss)/500))
x_range = np.arange(0, len(track_loss), step)
plt.figure()
plt.plot(x_range, track_loss[::step], 'C0')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.xlim([0, None])
plt.ylim([0, None])
plt.show()
def image_occlusion(x, image_shape):
"""
Randomly selects on quadrant of images and sets to zeros.
Args:
x (torch.Tensor of floats)
vectorized images
image_shape (tuple or list)
original shape of image
Returns:
torch.Tensor.
"""
selection = np.random.choice(4, len(x))
my_x = np.array(x).copy()
my_x = my_x.reshape(-1, image_shape[0], image_shape[1])
my_x[selection == 0, :int(image_shape[0] / 2), :int(image_shape[1] / 2)] = 0
my_x[selection == 1, int(image_shape[0] / 2):, :int(image_shape[1] / 2)] = 0
my_x[selection == 2, :int(image_shape[0] / 2), int(image_shape[1] / 2):] = 0
my_x[selection == 3, int(image_shape[0] / 2):, int(image_shape[1] / 2):] = 0
my_x = my_x.reshape(x.shape)
return torch.from_numpy(my_x)
def image_rotation(x, deg, image_shape):
"""
Randomly rotates images by +- deg degrees.
Args:
x (torch.Tensor of floats)
vectorized images
deg (integer)
rotation range
image_shape (tuple or list)
original shape of image
Returns:
torch.Tensor.
"""
my_x = np.array(x).copy()
my_x = my_x.reshape(-1, image_shape[0], image_shape[1])
for idx, item in enumerate(my_x):
my_deg = deg * 2 * np.random.random() - deg
my_x[idx] = ndimage.rotate(my_x[idx], my_deg,
reshape=False, prefilter=False)
my_x = my_x.reshape(x.shape)
return torch.from_numpy(my_x)
class AutoencoderClass(nn.Module):
"""
Deep autoencoder network object (nn.Module) with optional L2 normalization
of activations in bottleneck layer.
Args:
input_size (integer)
size of input samples
s2 (boolean)
whether to L2 normalize activatinos in bottleneck layer
Returns:
Autoencoder object inherited from nn.Module class.
"""
def __init__(self, input_size=784, s2=False):
super().__init__()
self.input_size = input_size
self.s2 = s2
if s2:
self.encoding_size = 3
else:
self.encoding_size = 2
self.enc1 = nn.Linear(self.input_size, int(self.input_size / 2))
self.enc1_f = nn.PReLU()
self.enc2 = nn.Linear(int(self.input_size / 2), self.encoding_size * 32)
self.enc2_f = nn.PReLU()
self.enc3 = nn.Linear(self.encoding_size * 32, self.encoding_size)
self.enc3_f = nn.PReLU()
self.dec1 = nn.Linear(self.encoding_size, self.encoding_size * 32)
self.dec1_f = nn.PReLU()
self.dec2 = nn.Linear(self.encoding_size * 32, int(self.input_size / 2))
self.dec2_f = nn.PReLU()
self.dec3 = nn.Linear(int(self.input_size / 2), self.input_size)
self.dec3_f = nn.Sigmoid()
def encoder(self, x):
"""
Encoder component.
"""
x = self.enc1_f(self.enc1(x))
x = self.enc2_f(self.enc2(x))
x = self.enc3_f(self.enc3(x))
if self.s2:
x = nn.functional.normalize(x, p=2, dim=1)
return x
def decoder(self, x):
"""
Decoder component.
"""
x = self.dec1_f(self.dec1(x))
x = self.dec2_f(self.dec2(x))
x = self.dec3_f(self.dec3(x))
return x
def forward(self, x):
"""
Forward pass.
"""
x = self.encoder(x)
x = self.decoder(x)
return x
def save_checkpoint(net, optimizer, filename):
"""
Saves a PyTorch checkpoint.
Args:
net (torch network)
ANN network (nn.Module)
optimizer (torch optimizer)
optimizer for SGD
filename (string)
filename (without extension)
Returns:
Nothing.
"""
torch.save({'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict()},
filename+'.pt')
def load_checkpoint(url, filename):
"""
Loads a PyTorch checkpoint from URL is local file not present.
Args:
url (string)
URL location of PyTorch checkpoint
filename (string)
filename (without extension)
Returns:
PyTorch checkpoint of saved model.
"""
if not os.path.isfile(filename+'.pt'):
os.system(f"wget {url}.pt")
return torch.load(filename+'.pt')
def reset_checkpoint(net, optimizer, checkpoint):
"""
Resets PyTorch model to checkpoint.
Args:
net (torch network)
ANN network (nn.Module)
optimizer (torch optimizer)
optimizer for SGD
checkpoint (torch checkpoint)
checkpoint of saved model
Returns:
Nothing.
"""
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
セクション 0: はじめに
# @title Video 1: Applications
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display
class PlayVideo(IFrame):
def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
self.id = id
if source == 'Bilibili':
src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
elif source == 'Osf':
src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
super(PlayVideo, self).__init__(src, width, height, **kwargs)
def display_videos(video_ids, W=400, H=300, fs=1):
tab_contents = []
for i, video_id in enumerate(video_ids):
out = widgets.Output()
with out:
if video_ids[i][0] == 'Youtube':
video = YouTubeVideo(id=video_ids[i][1], width=W,
height=H, fs=fs, rel=0)
print(f'Video available at https://youtube.com/watch?v={video.id}')
else:
video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
height=H, fs=fs, autoplay=False)
if video_ids[i][0] == 'Bilibili':
print(f'Video available at https://www.bilibili.com/video/{video.id}')
elif video_ids[i][0] == 'Osf':
print(f'Video available at https://osf.io/{video.id}')
display(video)
tab_contents.append(out)
return tab_contents
video_ids = [('Youtube', '_bzW_jkH6l0'), ('Bilibili', 'BV12v411q7nS')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Applications_Video")
セクション 1: MNISTデータセットのダウンロードと準備
ヘルパー関数 downloadMNIST を使ってデータセットをダウンロードし、torch.Tensor に変換して訓練セットとテストセットをそれぞれ (x_train, y_train) と (x_test, y_test) に割り当てます。
変数 input_size は、訓練用とテスト用の画像 input_train と input_test のベクトル化されたバージョンの長さを格納します。
指示:
- 以下のセルを実行してください
# Download MNIST
x_train, y_train, x_test, y_test = downloadMNIST()
x_train = x_train / 255
x_test = x_test / 255
image_shape = x_train.shape[1:]
input_size = np.prod(image_shape)
input_train = x_train.reshape([-1, input_size])
input_test = x_test.reshape([-1, input_size])
test_selected_idx = np.random.choice(len(x_test), 10, replace=False)
train_selected_idx = np.random.choice(len(x_train), 10, replace=False)
test_subset_idx = np.random.choice(len(x_test), 500, replace=False)
print(f'shape image \t\t {image_shape}')
print(f'shape input_train \t {input_train.shape}')
print(f'shape input_test \t {input_test.shape}')
セクション 2: 事前学習済みモデルのダウンロード
クラス AutoencoderClass は前回のチュートリアルで紹介したオートエンコーダーのアーキテクチャを実装しています。このクラスの設計はチュートリアル W3D4 のオブジェクト指向プログラミング(OOP)スタイルに従っています。ブールパラメータ s2=True を設定すると、 球面への射影を持つモデルが指定されます。
両モデルを n_epochs=25 で訓練し、長時間の初期訓練を避けるために重みを保存しました。これらが参照モデルの状態となります。
実験はすべて同一の初期条件から開始し、各演習の開始時にオートエンコーダーを参照状態にリセットします。
PyTorchでモデルを保存・読み込みする仕組みは以下の通りです:
model = nn.Sequential(...)
または
model = AutoencoderClass()
そして
torch.save({'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict()},
filename_path)
checkpoint = torch.load(filename_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
詳細はPyTorchの説明$を参照し、より複雑なモデルでは model.eval() と model.train() の使い分けも確認してください。
関数 save_checkpoint、load_checkpoint、reset_checkpoint を提供しており、上記の手順を実装しGitHubリポジトリから事前学習済み重みをダウンロードします。
GitHubからのダウンロードに失敗した場合は、以下の3番目のセルのコメントを外して n_epochs=10 でモデルを訓練し、ローカルに保存してください。
指示:
- 以下のセルを実行してください
root = 'https://github.com/mpbrigham/colaboratory-figures/raw/master/nma/autoencoders'
filename = 'ae_6h_prelu_bce_adam_25e_32b'
url = os.path.join(root, filename)
s2 = True
if s2:
filename += '_s2'
url += '_s2'
model = AutoencoderClass(s2=s2)
optimizer = optim.Adam(model.parameters())
encoder = model.encoder
decoder = model.decoder
checkpoint = load_checkpoint(url, filename)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# Please uncomment and execute this cell if download if pre-trained weights fail
# model = AutoencoderClass(s2=s2)
# encoder = model.encoder
# decoder = model.decoder
# n_epochs = 10
# batch_size = 128
# runSGD(model, input_train, input_test,
# n_epochs=n_epochs, batch_size=batch_size)
# save_checkpoint(model, optimizer, filename)
# checkpoint = load_checkpoint(url, filename)
with torch.no_grad():
output_test = model(input_test)
latent_test = encoder(input_test)
plot_row([input_test[test_selected_idx], output_test[test_selected_idx]],
image_shape=image_shape)
plot_latent_generative(latent_test, y_test, decoder,
image_shape=image_shape, s2=s2)
セクション 3: オートエンコーダーの応用
応用 1 - 画像ノイズ
画像に加えられたノイズを除去することは、次元削減技術でよく示されます。次元削減の日のチュートリアルではPCAでこの能力を示しました。
まず、ノイズなしの画像で訓練されたオートエンコーダーは、ノイズのある画像を入力として受け取るとノイズのない画像を出力することを観察します。ただし、再構成された画像は元の画像(ノイズなし)とは異なります。なぜなら、加えられたノイズが潜在空間の異なる座標にマッピングされるためです。
ノイズなしとノイズありのバージョンを潜在空間の類似領域にマッピングする能力は、ロバスト性またはノイズに対する不変性として知られています。オートエンコーダーにこの機能を組み込むにはどうすればよいでしょうか?
解決策は、ノイズなしとノイズありのバージョンをノイズなしバージョンにマッピングするようにオートエンコーダーを訓練することです。より速い方法は、ノイズあり画像で数エポックだけ再訓練することです。これらの短時間の訓練セッションは、ノイズあり画像を類似の潜在空間座標からノイズなしバージョンにマッピングするよう重みを微調整します。
まず、オートエンコーダーを参照状態にリセットしましょう。
指示:
- 以下のセルを実行してください
reset_checkpoint(model, optimizer, checkpoint)
with torch.no_grad():
latent_test_ref = encoder(input_test)
微調整前の再構成
ノイズなし画像で訓練されたオートエンコーダーが、ノイズのある入力からノイズなしの画像を出力することを確認しましょう。3行のプロットで可視化します:
- 上段はノイズのある画像の入力
- 中段はノイズのある画像の再構成
- 下段は元の画像(ノイズなし)の再構成
$
下段はノイズを加える前の再構成に問題があるサンプルを特定するのに役立ちます。この行は元画像ではなくこれらのサンプルの基準となる再構成品質を示しています。(なぜでしょう?)
指示:
- 以下のセルを実行してください
noise_factor = 0.4
input_train_noisy = (input_train
+ noise_factor * np.random.normal(size=input_train.shape))
input_train_noisy = np.clip(input_train_noisy, input_train.min(),
input_train.max(), dtype=np.float32)
input_test_noisy = (input_test
+ noise_factor * np.random.normal(size=input_test.shape))
input_test_noisy = np.clip(input_test_noisy, input_test.min(),
input_test.max(), dtype=np.float32)
with torch.no_grad():
output_test_noisy = model(input_test_noisy)
latent_test_noisy = encoder(input_test_noisy)
output_test = model(input_test)
plot_row([input_test_noisy[test_selected_idx],
output_test_noisy[test_selected_idx],
output_test[test_selected_idx]], image_shape=image_shape)
微調整前の潜在空間
入力にノイズを加えることが潜在空間の座標にどのように影響するかを調べ、再構成誤差の原因を探ります。デコーダーは座標の大きな変化を異なる数字として解釈します。
関数 plot_latent_ab は2つの条件間で同じサンプルセットの潜在空間座標を比較します。ここでは、前のセルの10サンプルのノイズなしとノイズありの座標を表示します:
- 左のプロットは元のサンプル(ノイズなし)の座標
- 右のプロットはノイズを加えた後の新しい座標
指示:
- 以下のセルを実行してください
plot_latent_ab(latent_test, latent_test_noisy, y_test, test_selected_idx,
title_a='Before noise', title_b='After noise', s2=s2)
ノイズあり画像でオートエンコーダーを微調整
ノイズあり画像を入力、元のノイズなし画像を出力としてオートエンコーダーを再訓練し、前のプロットを再生成します。
ノイズありとノイズなしの画像が類似の潜在空間位置にマッチすることがわかります。ネットワークはノイズに対してよりロバストな潜在空間表現で入力をデノイズします。
指示:
- 以下のセルを実行してください
n_epochs = 3
batch_size = 32
model.train()
runSGD(model, input_train_noisy, input_test_noisy,
out_train=input_train, out_test=input_test,
n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
output_test_noisy = model(input_test_noisy)
latent_test_noisy = encoder(input_test_noisy)
output_test = model(input_test)
plot_row([input_test_noisy[test_selected_idx],
output_test_noisy[test_selected_idx],
output_test[test_selected_idx]], image_shape=image_shape)
plot_latent_ab(latent_test, latent_test_noisy, y_test, test_selected_idx,
title_a='Before fine-tuning',
title_b='After fine-tuning', s2=s2)
潜在空間の全体的なシフト
新しい潜在空間表現はノイズに対してよりロバストで、データセットの内部表現が改善されている可能性があります。ノイズあり画像で微調整する前後のクリーン画像の潜在空間を調べて確認します。
ノイズあり画像でネットワークを微調整すると、データセットにドメインシフト(分布の変化)が生じます。元はノイズなし画像で構成されていたためです。再訓練のタスクや変化の程度(エポック数、オプティマイザの特性など)によっては、新しい潜在空間表現が副作用として元のデータに対して適応度が低下することがあります。ドメインシフトにどう対処し、ノイズあり・なし両方の画像を改善できるでしょうか?
指示:
- 以下のセルを実行してください
with torch.no_grad():
latent_test = encoder(input_test)
plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,
title_a='Before fine-tuning',
title_b='After fine-tuning', s2=s2)
応用 2 - 画像の部分遮蔽
次に画像の部分遮蔽の影響を調べます。前の演習から、訓練セットに遮蔽画像が含まれていないため、オートエンコーダーは完全な画像を再構成すると予想されます(そうですよね?)。
3行のプロットで可視化します:
- 上段は遮蔽された画像
- 中段は遮蔽画像の再構成
- 下段は元の画像の再構成
$
同様に、潜在空間における部分画像の表現と微調整後の変化を調べ、この問題の原因を探ります。
指示:
- 以下のセルを実行してください
reset_checkpoint(model, optimizer, checkpoint)
with torch.no_grad():
latent_test_ref = encoder(input_test)
微調整前
指示:
- 以下のセルを実行してください
input_train_mask = image_occlusion(input_train, image_shape=image_shape)
input_test_mask = image_occlusion(input_test, image_shape=image_shape)
with torch.no_grad():
output_test_mask = model(input_test_mask)
latent_test_mask = encoder(input_test_mask)
output_test = model(input_test)
plot_row([input_test_mask[test_selected_idx],
output_test_mask[test_selected_idx],
output_test[test_selected_idx]], image_shape=image_shape)
plot_latent_ab(latent_test, latent_test_mask, y_test, test_selected_idx,
title_a='Before occlusion', title_b='After occlusion', s2=s2)
微調整後
n_epochs = 3
batch_size = 32
model.train()
runSGD(model, input_train_mask, input_test_mask,
out_train=input_train, out_test=input_test,
n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
output_test_mask = model(input_test_mask)
latent_test_mask = encoder(input_test_mask)
output_test = model(input_test)
plot_row([input_test_mask[test_selected_idx],
output_test_mask[test_selected_idx],
output_test[test_selected_idx]], image_shape=image_shape)
plot_latent_ab(latent_test, latent_test_mask, y_test, test_selected_idx,
title_a='Before fine-tuning',
title_b='After fine-tuning', s2=s2)
with torch.no_grad():
latent_test = encoder(input_test)
plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,
title_a='Before fine-tuning',
title_b='After fine-tuning', s2=s2)
応用 3 - 画像の回転
最後に画像の回転が潜在空間座標に与える影響を見ます。この課題は画像再構成の完全な書き換えを必要とするため、より難しいかもしれません。
3行のプロットで可視化します:
- 上段は回転した画像
- 中段は回転画像の再構成
- 下段は元の画像の再構成
$
回転画像の潜在空間表現と微調整後の変化を調べ、この問題の原因を探ります。
指示:
- 以下のセルを実行してください
reset_checkpoint(model, optimizer, checkpoint)
with torch.no_grad():
latent_test_ref = encoder(input_test)
微調整前
指示:
- 以下のセルを実行してください
input_train_rotation = image_rotation(input_train, 90, image_shape=image_shape)
input_test_rotation = image_rotation(input_test, 90, image_shape=image_shape)
with torch.no_grad():
output_test_rotation = model(input_test_rotation)
latent_test_rotation = encoder(input_test_rotation)
output_test = model(input_test)
plot_row([input_test_rotation[test_selected_idx],
output_test_rotation[test_selected_idx],
output_test[test_selected_idx]], image_shape=image_shape)
plot_latent_ab(latent_test, latent_test_rotation, y_test, test_selected_idx,
title_a='Before rotation', title_b='After rotation', s2=s2)
微調整後
指示:
- 以下のセルを実行してください
n_epochs = 5
batch_size = 32
model.train()
runSGD(model, input_train_rotation, input_test_rotation,
out_train=input_train, out_test=input_test,
n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
output_test_rotation = model(input_test_rotation)
latent_test_rotation = encoder(input_test_rotation)
output_test = model(input_test)
plot_row([input_test_rotation[test_selected_idx],
output_test_rotation[test_selected_idx],
output_test[test_selected_idx]], image_shape=image_shape)
plot_latent_ab(latent_test, latent_test_rotation, y_test, test_selected_idx,
title_a='Before fine-tuning',
title_b='After fine-tuning', s2=s2)
with torch.no_grad():
latent_test = encoder(input_test)
plot_latent_ab(latent_test_ref, latent_test, y_test, test_subset_idx,
title_a='Before fine-tuning',
title_b='After fine-tuning', s2=s2)
応用 4 - もし数字「6」を一度も見たことがなかったらどのように見えるか?
そんな不可能な課題で頭を悩ませる前に、オートエンコーダーにやらせてみましょう!
数字クラス 6 を除いてオートエンコーダーを最初から訓練し、数字 6 の再構成を可視化します。
指示:
- 以下のセルを実行してください
model = AutoencoderClass(s2=s2)
optimizer = optim.Adam(model.parameters())
encoder = model.encoder
decoder = model.decoder
missing = 6
my_input_train = input_train[y_train != missing]
my_input_test = input_test[y_test != missing]
my_y_test = y_test[y_test != missing]
n_epochs = 3
batch_size = 32
runSGD(model, my_input_train, my_input_test,
n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
output_test = model(input_test)
my_latent_test = encoder(my_input_test)
plot_row([input_test[y_test == 6], output_test[y_test == 6]],
image_shape=image_shape)
plot_latent_generative(my_latent_test, my_y_test, decoder,
image_shape=image_shape, s2=s2)
コーディング演習 1: 最も支配的な数字クラスの除去
数字クラス 0 と 1 は、他の数字クラスに比べてデコーダーグリッドの大きな領域を占めているため支配的です。
最も支配的な2つの数字クラスを除去すると潜在空間はどう変わるでしょうか?残りのクラスに均等に再分布するでしょうか、それとも別の2つの支配的クラスを選ぶでしょうか?
指示:
- 以下のセルを実行してください
- 条件による2つのブール配列の交差は
x[(cond_a)&(cond_b)]と指定します
model = AutoencoderClass(s2=s2)
optimizer = optim.Adam(model.parameters())
encoder = model.encoder
decoder = model.decoder
missing_a = 1
missing_b = 0
#####################################################################
# Fill in missing code (...),
# then remove or comment the line below to test your function
raise NotImplementedError("Complete the code elements below!")
#####################################################################
# input train data
my_input_train = ...
# input test data
my_input_test = ...
# model
my_y_test = ...
print(my_input_train.shape)
print(my_input_test.shape)
print(my_y_test.shape)
出力例
torch.Size([47335, 784])
torch.Size([7885, 784])
torch.Size([7885])
n_epochs = 3
batch_size = 32
runSGD(model, my_input_train, my_input_test,
n_epochs=n_epochs, batch_size=batch_size)
with torch.no_grad():
output_test = model(input_test)
my_latent_test = encoder(my_input_test)
plot_row([input_test[y_test == missing_a], output_test[y_test == missing_a]],
image_shape=image_shape)
plot_row([input_test[y_test == missing_b], output_test[y_test == missing_b]],
image_shape=image_shape)
plot_latent_generative(my_latent_test, my_y_test, decoder,
image_shape=image_shape, s2=s2)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Removing_the_most_dominant_class_Exercise")
セクション 4: ANN?同じだけど違う!
「Same same but different(同じようで違う)」は、アジアの一部で似ているはずの対象の違いを表現する言葉です。この演習では、全結合ANNが人間の視覚と比較して視覚情報を処理する根本的な違いを調べます。
前の演習ではANNオートエンコーダーが認知課題を比較的容易にこなすことを示しました。しかし、画像のベクトル化にすでにANN処理の重要な側面が符号化されています。このネットワークアーキテクチャはピクセルの相対位置を完全に無視します。これを示すために、ピクセル位置をシャッフルしても学習が同様に進むことを示します。
まず、shuffle_image_idx に格納された可逆なシャッフルマップを取得し、画像のピクセルをランダムにシャッフルします。
$
シャッフルされていない画像セット input_shuffle は以下で復元されます:
input_shuffle[:, shuffle_rev_image_idx]]
まず、可逆シャッフルマップを設定し、シャッフルされたピクセルとシャッフル解除されたピクセルの画像をいくつか可視化し、その後ノイズありバージョンも表示します。
指示:
- 以下のセルを実行してください
# create forward and reverse indexes for pixel shuffling
shuffle_image_idx = np.arange(input_size)
shuffle_rev_image_idx = np.empty_like(shuffle_image_idx)
# shuffle pixel location
np.random.shuffle(shuffle_image_idx)
# store reverse locations
for pos_idx, pos in enumerate(shuffle_image_idx):
shuffle_rev_image_idx[pos] = pos_idx
# shuffle train and test sets
input_train_shuffle = input_train[:, shuffle_image_idx]
input_test_shuffle = input_test[:, shuffle_image_idx]
input_train_shuffle_noisy = input_train_noisy[:, shuffle_image_idx]
input_test_shuffle_noisy = input_test_noisy[:, shuffle_image_idx]
# show samples with shuffled pixels
plot_row([input_test_shuffle,
input_test_shuffle[:, shuffle_rev_image_idx]],
image_shape=image_shape)
# show noisy samples with shuffled pixels
plot_row([input_train_shuffle_noisy[test_selected_idx],
input_train_shuffle_noisy[:, shuffle_rev_image_idx][test_selected_idx]],
image_shape=image_shape)
ノイズ除去タスクでシャッフルされたピクセルを使ってネットワークを初期化し訓練します。
指示:
- 以下のセルを実行してください
model = AutoencoderClass(s2=s2)
encoder = model.encoder
decoder = model.decoder
n_epochs = 3
batch_size = 32
# train the model to denoise shuffled images
runSGD(model, input_train_shuffle_noisy, input_test_shuffle_noisy,
out_train=input_train_shuffle, out_test=input_test_shuffle,
n_epochs=n_epochs, batch_size=batch_size)
最後に、訓練済みモデルで再構成と潜在空間表現を可視化します。
再構成は3行で可視化します:
- 上段はシャッフルされたノイズあり画像
- 中段はシャッフルされたデノイズ画像の再構成
- 下段はデノイズされた画像のシャッフル解除再構成
$
エンコーダーマップは以前と同様の構造を示します。類似の内部表現を共有することは、ネットワークがピクセルの相対位置を無視していることを確認します。デコーダーグリッドはシャッフルされた画像を生成するため以前とは異なります。
指示:
- 以下のセルを実行してください
with torch.no_grad():
latent_test_shuffle_noisy = encoder(input_test_shuffle_noisy)
output_test_shuffle_noisy = model(input_test_shuffle_noisy)
plot_row([input_test_shuffle_noisy[test_selected_idx],
output_test_shuffle_noisy[test_selected_idx],
output_test_shuffle_noisy[:, shuffle_rev_image_idx][test_selected_idx]],
image_shape=image_shape)
plot_latent_generative(latent_test_shuffle_noisy, y_test, decoder,
image_shape=image_shape, s2=s2)
まとめ
おめでとうございます!NMA 2020の最後のチュートリアルを修了しました!
これらのチュートリアルを楽しんでいただき、オートエンコーダーがデータの豊かで非線形な低次元構造をモデル化するのに有用であることを学んでいただけたことを願っています。認知の特定の側面をモデル化したり、生物学的に妥当なアーキテクチャ(スパイキングニューロンのオートエンコーダーなど)に拡張したりする際に役立つかもしれません。
これらのチュートリアルからの主なメッセージは以下の通りです:
圧縮・復元やノイズ除去などの実践的学習タスクで訓練されたオートエンコーダーは、構造化された画像や他の認知的に関連するデータに埋め込まれた豊かな低次元構造を明らかにできる。
訓練時に見たデータドメインは「認知バイアス」を刻印する — 期待するものしか見えず、それは以前に見たものに似ているだけである。
このバイアスは心理学者ダニエル・カーネマンが提唱した見えるものがすべて$の概念に関連しています。
神経科学へのオートエンコーダーの追加応用については、アウトロビデオのスパイクソーティング応用を参照し、またこちらで実際のニューロンネットワークの入出力関係をオートエンコーダーで再現する方法をご覧ください。
# @title Video 2: Wrap-up
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display
class PlayVideo(IFrame):
def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
self.id = id
if source == 'Bilibili':
src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
elif source == 'Osf':
src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
super(PlayVideo, self).__init__(src, width, height, **kwargs)
def display_videos(video_ids, W=400, H=300, fs=1):
tab_contents = []
for i, video_id in enumerate(video_ids):
out = widgets.Output()
with out:
if video_ids[i][0] == 'Youtube':
video = YouTubeVideo(id=video_ids[i][1], width=W,
height=H, fs=fs, rel=0)
print(f'Video available at https://youtube.com/watch?v={video.id}')
else:
video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
height=H, fs=fs, autoplay=False)
if video_ids[i][0] == 'Bilibili':
print(f'Video available at https://www.bilibili.com/video/{video.id}')
elif video_ids[i][0] == 'Osf':
print(f'Video available at https://osf.io/{video.id}')
display(video)
tab_contents.append(out)
return tab_contents
video_ids = [('Youtube', 'ziiZK9P6AXQ'), ('Bilibili', 'BV1ph411Z7uh')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_WrapUp_Video")