チュートリアル 3: 画像、条件付き拡散とその先へ
第2週、第4日目: 今日のテーマ
Neuromatch Academy 提供
コンテンツ作成者: Binxu Wang
コンテンツレビュアー: Shaonan Wang, Dongrui Deng, Dora Zhiyu Yang, Adrita Das
コンテンツ編集者: Shaonan Wang
制作編集者: Spiros Chavlis
チュートリアルの目標
- 拡散生成モデルの基本的な考え方を理解する:スコア関数と拡散過程の逆転。
- データのノイズ除去によってスコア関数を学習する方法を学ぶ。
- スコアを学習して特定の分布を生成する実践的な経験を積む。
# @title Tutorial slides
from IPython.display import IFrame
link_id = "j89qg"
print(f"If you want to download the slides: https://osf.io/download/{link_id}/")
IFrame(src=f"https://mfr.ca-1.osf.io/render?url=https://osf.io/{link_id}/?direct%26mode=render%26action=download%26mode=render", width=854, height=480)
セットアップ
# @title Install dependencies
# @markdown **WARNING**: There may be *errors* and/or *warnings* reported during the installation. However, they are to be ignored.
# @title Install and import feedback gadget
from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
return DatatopsContentReviewContainer(
"", # No text prompt
notebook_section,
{
"url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
"name": "neuromatch_dl",
"user_key": "f379rz8y",
},
).render()
feedback_prefix = "W2D4_T3"
# Imports
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from tqdm.notebook import trange, tqdm
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR
from torchvision.utils import make_grid
# @title Figure settings
import ipywidgets as widgets # interactive display
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle")
# @title Set random seed
# @markdown Executing `set_seed(seed=seed)` you are setting the seed
# For DL its critical to set the random seed so that students can have a
# baseline to compare their results to expected results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html
# Call `set_seed` function in the exercises to ensure reproducibility.
import random
import torch
def set_seed(seed=None, seed_torch=True):
"""
Function that controls randomness.
NumPy and random modules must be imported.
Args:
seed : Integer
A non-negative integer that defines the random state. Default is `None`.
seed_torch : Boolean
If `True` sets the random seed for pytorch tensors, so pytorch module
must be imported. Default is `True`.
Returns:
Nothing.
"""
if seed is None:
seed = np.random.choice(2 ** 32)
random.seed(seed)
np.random.seed(seed)
if seed_torch:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
print(f'Random seed {seed} has been set.')
# In case that `DataLoader` is used
def seed_worker(worker_id):
"""
DataLoader will reseed workers following randomness in
multi-process data loading algorithm.
Args:
worker_id: integer
ID of subprocess to seed. 0 means that
the data will be loaded in the main process
Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details
Returns:
Nothing
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
# @title Set device (GPU or CPU). Execute `set_device()`
# Inform the user if the notebook uses GPU or CPU.
def set_device():
"""
Set the device. CUDA if available, CPU otherwise
Args:
None
Returns:
Nothing
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
print("WARNING: For this notebook to perform best, "
"if possible, in the menu under `Runtime` -> "
"`Change runtime type.` select `GPU` ")
else:
print("GPU is enabled in this notebook.")
return device
DEVICE = set_device()
SEED = 2021
set_seed(seed=SEED)
ニューラルネットワークのアーキテクチャ
我々はちょうど拡散モデルの基本原理を学びました。ポイントは、スコア関数によって純粋なノイズを興味深いデータ分布に変換できるということです。さらに、スコア関数をノイズ除去スコアマッチングを通じてニューラルネットワークで近似します。しかし、画像を扱う際には、ニューラルネットワークが画像と『うまく連携』し、画像に関連する帰納的バイアスを反映する必要があります。
合理的な選択肢として、ニューラルネットワークのアーキテクチャを**U-Net**にすることが挙げられます。これはCNNに似たアーキテクチャで、以下の特徴があります:
- 画像の異なる空間スケールの特徴を処理するためのダウンスケーリング/アップスケーリング操作。
- 情報の高速道路としてのスキップ接続。
我々が学習しようとしているスコア関数は時間の関数でもあるため、ニューラルネットワークが時間の変化に適切に応答する方法も考案する必要があります。この目的のために、**時間埋め込み(time embedding)**を使用できます。
# @title Video 1: Network architecture
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display
class PlayVideo(IFrame):
def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
self.id = id
if source == 'Bilibili':
src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
elif source == 'Osf':
src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
super(PlayVideo, self).__init__(src, width, height, **kwargs)
def display_videos(video_ids, W=400, H=300, fs=1):
tab_contents = []
for i, video_id in enumerate(video_ids):
out = widgets.Output()
with out:
if video_ids[i][0] == 'Youtube':
video = YouTubeVideo(id=video_ids[i][1], width=W,
height=H, fs=fs, rel=0)
print(f'Video available at https://youtube.com/watch?v={video.id}')
else:
video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
height=H, fs=fs, autoplay=False)
if video_ids[i][0] == 'Bilibili':
print(f'Video available at https://www.bilibili.com/video/{video.id}')
elif video_ids[i][0] == 'Osf':
print(f'Video available at https://osf.io/{video.id}')
display(video)
tab_contents.append(out)
return tab_contents
video_ids = [('Youtube', 'sV-ROEAZaO0'), ('Bilibili', 'BV1Yk4y1N7Ai')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Network_Architecture_Video")
コーディング演習1:MNISTのための拡散モデルの訓練
最後に、MNISTデータセットのための実際の画像拡散モデルを実装し、訓練しましょう。
スコア近似器のニューラルネットワークアーキテクチャを調べることで、我々が組み込んだ帰納的バイアスを理解できます。
次のセルでは、順方向過程のためのヘルパー関数を実装します。
- のための
marginal_prob_std(注意:分散ではなく標準偏差です) - のための
diffusion_coeff
順方向過程の数学的復習:
前回のチュートリアルと同じ順方向過程(分散爆発型SDE)を使用します。これは次のように表されます:
ここで拡散係数を とし、 とします。
この場合、初期状態 が与えられたときの時刻 における状態 の周辺分布はガウス分布 となります。分散は拡散係数の二乗の積分です。
def marginal_prob_std(t, Lambda, device='cpu'):
"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.
Args:
t: A vector of time steps.
Lambda: The $\lambda$ in our SDE.
Returns:
std : The standard deviation.
"""
t = t.to(device)
#################################################
## TODO for students: Implement the standard deviation
raise NotImplementedError("Student exercise: Implement the standard deviation")
#################################################
std = ...
return std
def diffusion_coeff(t, Lambda, device='cpu'):
"""Compute the diffusion coefficient of our SDE.
Args:
t: A vector of time steps.
Lambda: The $\lambda$ in our SDE.
Returns:
diff_coeff : The vector of diffusion coefficients.
"""
#################################################
## TODO for students: Implement the diffusion coefficients
raise NotImplementedError("Student exercise: Implement the diffusion coefficients")
#################################################
diff_coeff = ...
return diff_coeff.to(device)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Train_Diffusion_for_MNIST_Exercise")
ネットワークアーキテクチャ
以下はシンプルな時間埋め込みと変調レイヤーのコードです。基本的に、時間 はサインとコサインの基底として多重化され、その後線形読み出しによって時間変調信号が生成されます。
# @title Time embedding and modulation
class GaussianFourierProjection(nn.Module):
"""Gaussian random features for encoding time steps."""
def __init__(self, embed_dim, scale=30.):
super().__init__()
# Randomly sample weights (frequencies) during initialization.
# These weights (frequencies) are fixed during optimization and are not trainable.
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
def forward(self, x):
# Cosine(2 pi freq x), Sine(2 pi freq x)
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
class Dense(nn.Module):
"""A fully connected layer that reshapes outputs to feature maps.
Allow time repr to input additively from the side of a convolution layer.
"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.dense = nn.Linear(input_dim, output_dim)
def forward(self, x):
# this broadcast the 2d tensor to 4d, add the same value across space.
return self.dense(x)[..., None, None]
以下はシンプルなU-Netアーキテクチャのコードです。拡散モデルはアーキテクチャの細部によって成功度が異なる場合があります。この例は主に説明目的です。
# @title Time-dependent UNet score model
class UNet(nn.Module):
"""A time-dependent score-based model built upon U-Net architecture."""
def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
"""Initialize a time-dependent score-based network.
Args:
marginal_prob_std: A function that takes time t and gives the standard
deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
channels: The number of channels for feature maps of each resolution.
embed_dim: The dimensionality of Gaussian random feature embeddings.
"""
super().__init__()
# Gaussian random feature embedding layer for time
self.time_embed = nn.Sequential(
GaussianFourierProjection(embed_dim=embed_dim),
nn.Linear(embed_dim, embed_dim)
)
# Encoding layers where the resolution decreases
self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
self.t_mod1 = Dense(embed_dim, channels[0])
self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
self.t_mod2 = Dense(embed_dim, channels[1])
self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
self.t_mod3 = Dense(embed_dim, channels[2])
self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
self.t_mod4 = Dense(embed_dim, channels[3])
self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
# Decoding layers where the resolution increases
self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
self.t_mod5 = Dense(embed_dim, channels[2])
self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
self.t_mod6 = Dense(embed_dim, channels[1])
self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
self.t_mod7 = Dense(embed_dim, channels[0])
self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
# The swish activation function
self.act = lambda x: x * torch.sigmoid(x)
# A restricted version of the `marginal_prob_std` function, after specifying a Lambda.
self.marginal_prob_std = marginal_prob_std
def forward(self, x, t, y=None):
# Obtain the Gaussian random feature embedding for t
embed = self.act(self.time_embed(t))
# Encoding path, downsampling
## Incorporate information from t
h1 = self.conv1(x) + self.t_mod1(embed)
## Group normalization and apply activation function
h1 = self.act(self.gnorm1(h1))
# 2nd conv
h2 = self.conv2(h1) + self.t_mod2(embed)
h2 = self.act(self.gnorm2(h2))
# 3rd conv
h3 = self.conv3(h2) + self.t_mod3(embed)
h3 = self.act(self.gnorm3(h3))
# 4th conv
h4 = self.conv4(h3) + self.t_mod4(embed)
h4 = self.act(self.gnorm4(h4))
# Decoding path up sampling
h = self.tconv4(h4) + self.t_mod5(embed)
## Skip connection from the encoding path
h = self.act(self.tgnorm4(h))
h = self.tconv3(torch.cat([h, h3], dim=1)) + self.t_mod6(embed)
h = self.act(self.tgnorm3(h))
h = self.tconv2(torch.cat([h, h2], dim=1)) + self.t_mod7(embed)
h = self.act(self.tgnorm2(h))
h = self.tconv1(torch.cat([h, h1], dim=1))
# Normalize output
h = h / self.marginal_prob_std(t)[:, None, None, None]
return h
考えてみよう!1: U-Netアーキテクチャ
U-Netアーキテクチャを見て、以下の操作に対応するモジュールを見つけられますか?
- 空間特徴のダウンサンプリング?
- 空間特徴のアップサンプリング?
- ダウンブランチからアップブランチへのスキップ接続はどのように実装されている?
- 時間変調はどのように実装されている?
- 出力が で割られているのはなぜ?これがスコア学習にどのように役立つか、あるいは害になるかもしれない?
2分間静かに考え、その後グループで議論しましょう(約10分)。
# @title Submit your feedback
content_review(f"{feedback_prefix}_UNet_Architecture_Discussion")
コーディング演習2: 損失関数の定義
次のセルでは、前回のチュートリアルで使ったデノイジングスコアマッチング(DSM)目的関数を実装します。
ここで時間の重み付けは と選ばれており、低ノイズ期間()よりも高ノイズ期間()を強調しています。
ヒント:
- 前回との主な違いは、スコア 、ノイズ 、状態 がすべてバッチ画像形状のテンソルであるため、 を適切にブロードキャストする必要があることです。例えば
std[:, None, None, None]が役立ちます。 epsは非常に小さなノイズスケールのスコア関数学習を防ぐために設定されています。非常に不規則になるためです。
def loss_fn(model, x, marginal_prob_std, eps=1e-3, device='cpu'):
"""The loss function for training score-based generative models.
Args:
model: A PyTorch model instance that represents a
time-dependent score-based model.
Note, it takes two inputs in its forward function model(x, t)
$s_\theta(x,t)$ in the equation
x: A mini-batch of training data.
marginal_prob_std: A function that gives the standard deviation of
the perturbation kernel, takes `t` as input.
$\sigma_t$ in the equation.
eps: A tolerance value for numerical stability.
"""
# Sample time uniformly in eps, 1
random_t = torch.rand(x.shape[0], device=device) * (1. - eps) + eps
# Find the noise std at the time `t`
std = marginal_prob_std(random_t).to(device)
#################################################
## TODO for students: Implement the denoising score matching eq.
raise NotImplementedError("Student exercise: Implement the denoising score matching eq. ")
#################################################
# get normally distributed noise N(0, I)
z = ...
# compute the perturbed x = x + z * \sigma_t
perturbed_x = ...
# predict score with the model at (perturbed x, t)
score = ...
# compute distance between the score and noise \| score * sigma_t + z \|_2^2
loss = ...
##############
return loss
正しく実装された損失関数は以下のテストに合格します。
単一の 0 データポイントを持つデータセットに対して、解析的スコアは です。この場合、解析的スコアの損失がゼロになることをテストします。
# @title Test loss function
marginal_prob_std_test = lambda t: marginal_prob_std(t, Lambda=10, device='cpu')
score_analyt_test = lambda x_t, t: - x_t / marginal_prob_std_test(t)[:,None,None,None]**2
x_test = torch.zeros(10, 3, 64, 64)
loss = loss_fn(score_analyt_test, x_test, marginal_prob_std_test, eps=1e-3, device='cpu')
assert torch.allclose(loss,torch.zeros(1)), "the loss should be zero in this case"
# @title Submit your feedback
content_review(f"{feedback_prefix}_Defining_the_loss_function_Exercise")
拡散モデルの訓練とテスト
注意: n_epochs を12に減らしていますが、必要に応じて増やしても構いません。元の値は100でしたが、訓練に時間がかかる場合は n_epochs=50、batch_size=1024 でも十分です。平均損失が約30程度であれば、許容できる数字を生成できます。
# @title Training the model
Lambda = 25.0 # @param {'type':'number'}
marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=DEVICE)
diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=Lambda, device=DEVICE)
score_model = UNet(marginal_prob_std=marginal_prob_std_fn)
score_model = score_model.to(DEVICE)
n_epochs = 12 # @param {'type':'integer'}
# size of a mini-batch
batch_size = 1024 # @param {'type':'integer'}
# learning rate
lr = 10e-4 # @param {'type':'number'}
set_seed(SEED)
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
g = torch.Generator()
g.manual_seed(SEED)
data_loader = DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=2,
worker_init_fn=seed_worker,
generator=g,)
optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 1 - epoch / n_epochs))
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
pbar = tqdm(data_loader)
for x, y in pbar:
x = x.to(DEVICE)
loss = loss_fn(score_model, x, marginal_prob_std_fn, eps=0.01, device=DEVICE)
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item() * x.shape[0]
num_items += x.shape[0]
scheduler.step()
print(f"Average Loss: {(avg_loss / num_items):5f} lr {scheduler.get_last_lr()[0]:.1e}")
# Print the averaged training loss so far.
tqdm_epoch.set_description(f'Average Loss: {(avg_loss / num_items):.5f}')
# Update the checkpoint after each epoch of training.
torch.save(score_model.state_dict(), 'ckpt.pth')
# @title Define the Sampler
def Euler_Maruyama_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
x_shape=(1, 28, 28),
num_steps=500,
device='cuda',
eps=1e-3, y=None):
"""Generate samples from score-based models with the Euler-Maruyama solver.
Args:
score_model: A PyTorch model that represents the time-dependent score-based model.
marginal_prob_std: A function that gives the standard deviation of
the perturbation kernel.
diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
batch_size: The number of samplers to generate by calling this function once.
num_steps: The number of sampling steps.
Equivalent to the number of discretized time steps.
device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
eps: The smallest time step for numerical stability.
Returns:
Samples.
"""
t = torch.ones(batch_size).to(device)
r = torch.randn(batch_size, *x_shape).to(device)
init_x = r * marginal_prob_std(t)[:, None, None, None]
init_x = init_x.to(device)
time_steps = torch.linspace(1., eps, num_steps).to(device)
step_size = time_steps[0] - time_steps[1]
x = init_x
with torch.no_grad():
for time_step in tqdm(time_steps):
batch_time_step = torch.ones(batch_size, device=device) * time_step
g = diffusion_coeff(batch_time_step)
mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size
x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
# Do not include any noise in the last sampling step.
return mean_x
# @title Sampling
def save_samples_uncond(score_model, suffix="", device='cpu'):
score_model.eval()
## Generate samples using the specified sampler.
sample_batch_size = 64 # @param {'type':'integer'}
num_steps = 250 # @param {'type':'integer'}
# score_model.eval()
## Generate samples using the specified sampler.
samples = Euler_Maruyama_sampler(score_model,
marginal_prob_std_fn,
diffusion_coeff_fn,
sample_batch_size,
num_steps=num_steps,
device=DEVICE,
eps=0.001)
# Sample visualization.
samples = samples.clamp(0.0, 1.0)
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
sample_np = sample_grid.permute(1, 2, 0).cpu().numpy()
plt.imsave(f"uncondition_diffusion{suffix}.png", sample_np, )
plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(sample_np, vmin=0., vmax=1.)
plt.show()
marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=DEVICE)
uncond_score_model = UNet(marginal_prob_std=marginal_prob_std_fn)
uncond_score_model.load_state_dict(torch.load("ckpt.pth"))
uncond_score_model.to(DEVICE)
save_samples_uncond(uncond_score_model, suffix="", device=DEVICE)
よくできました!これでディフュージョンモデルのトレーニングが完了しました。ご覧の通り、結果は理想的とは言えず、多くの要因が影響しています。いくつか挙げると:
- より良いネットワークアーキテクチャ:残差接続、アテンション機構、より良いアップサンプリング機構
- より良い目的関数:より良い重み付け関数
- より良い最適化手法:学習率減衰の利用
- より良いサンプリングアルゴリズム:オイラー積分は誤差が大きいことで知られているため、より高度なSDEまたはODEソルバーの使用が推奨されます
セクション2:条件付きディフュージョンモデル
結果を大幅に改善するもう一つの方法は条件付き信号を加えることです。例えば、スコアネットワークにどの数字を生成したいかを伝えることです。これによりスコアモデリングがずっと楽になり、ユーザーに制御性をもたらします。人気のあるStable Diffusionモデルはこのタイプの一つで、画像の条件信号として自然言語テキストを使用しています。
# @title Video 2: Conditional Diffusion Model
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display
class PlayVideo(IFrame):
def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
self.id = id
if source == 'Bilibili':
src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
elif source == 'Osf':
src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
super(PlayVideo, self).__init__(src, width, height, **kwargs)
def display_videos(video_ids, W=400, H=300, fs=1):
tab_contents = []
for i, video_id in enumerate(video_ids):
out = widgets.Output()
with out:
if video_ids[i][0] == 'Youtube':
video = YouTubeVideo(id=video_ids[i][1], width=W,
height=H, fs=fs, rel=0)
print(f'Video available at https://youtube.com/watch?v={video.id}')
else:
video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
height=H, fs=fs, autoplay=False)
if video_ids[i][0] == 'Bilibili':
print(f'Video available at https://www.bilibili.com/video/{video.id}')
elif video_ids[i][0] == 'Osf':
print(f'Video available at https://osf.io/{video.id}')
display(video)
tab_contents.append(out)
return tab_contents
video_ids = [('Youtube', 'tJDdVN9Fnrs'), ('Bilibili', 'BV1ek4y1N7bs')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Conditional_Diffusion_Model_Video")
数式的には、条件付きディフュージョンは無条件ディフュージョンと非常に似ています。
条件付きディフュージョンモデルの構築とトレーニング方法に興味がある方は、最後のボーナス演習をご覧ください。
# @title Video 3: Advanced Techinque - Stable Diffusion
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display
class PlayVideo(IFrame):
def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
self.id = id
if source == 'Bilibili':
src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
elif source == 'Osf':
src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
super(PlayVideo, self).__init__(src, width, height, **kwargs)
def display_videos(video_ids, W=400, H=300, fs=1):
tab_contents = []
for i, video_id in enumerate(video_ids):
out = widgets.Output()
with out:
if video_ids[i][0] == 'Youtube':
video = YouTubeVideo(id=video_ids[i][1], width=W,
height=H, fs=fs, rel=0)
print(f'Video available at https://youtube.com/watch?v={video.id}')
else:
video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
height=H, fs=fs, autoplay=False)
if video_ids[i][0] == 'Bilibili':
print(f'Video available at https://www.bilibili.com/video/{video.id}')
elif video_ids[i][0] == 'Osf':
print(f'Video available at https://osf.io/{video.id}')
display(video)
tab_contents.append(out)
return tab_contents
video_ids = [('Youtube', 'HBLgRqxgxrY'), ('Bilibili', 'BV1Yh4y1M74g')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Advanced_Techinque_Stable_Diffusion_Video")
インタラクティブデモ2:Stable Diffusion
このデモでは、最も強力なオープンソースのディフュージョンモデルの一つであるStable Diffusion 2.1を使って遊び、学んだことと結びつけてみます。
#@title Download the Stable Diffusion models
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, PNDMScheduler
model_id = "stabilityai/stable-diffusion-2-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# Use the PNDM scheduler as default
# pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
# Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here instead
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to(DEVICE)
これで想像力を解き放ち、テキストからアート作品を作成できます!
例のプロンプト:
prompt = "バン・ゴッホ風の砂漠を走るかわいい猫、トレンドのアート。"
prompt = "モネ風の星空の下で踊るバレリーナ、トレンドのアート。"
prompt = "A lovely cat running on the dessert in Van Gogh style, trending art." # @param {'type':'string'}
my_seed = 2023 # @param {'type':'integer'}
execute = False # @param {'type':'boolean'}
if execute:
image = pipe(prompt, num_inference_steps=50,
generator=torch.Generator("cuda").manual_seed(my_seed)).images[0]
image
# @title Submit your feedback
content_review(f"{feedback_prefix}_Stable_Diffusion_Interactive_Demo")
考えてみよう!2:Stable Diffusionモデルのアーキテクチャ
Stable DiffusionのU-Netと、上で定義したベビーUNetの類似点が見えますか?
アーキテクチャを調べるには、関数を異なるdeepestで使うことができます。
テキストはCLIPモデルを通じてエンコードされており、その構造もで見ることができます。これは大きなトランスフォーマーです!
2分間考えたりコードを触ったりしてから、グループで議論しましょう(約10分)。
# @title Helper function to inspect network
def recursive_print(module, prefix="", depth=0, deepest=3):
"""Simulating print(module) for torch.nn.Modules
but with depth control. Print to the `deepest` level. `deepest=0` means no print
"""
if depth == 0:
print(f"[{type(module).__name__}]")
if depth >= deepest:
return
for name, child in module.named_children():
if len([*child.named_children()]) == 0:
print(f"{prefix}({name}): {child}")
else:
if isinstance(child, nn.ModuleList):
print(f"{prefix}({name}): {type(child).__name__} len={len(child)}")
else:
print(f"{prefix}({name}): {type(child).__name__}")
recursive_print(child, prefix=prefix + " ", depth=depth + 1, deepest=deepest)
recursive_print(pipe.unet,deepest=2)
recursive_print(pipe.text_encoder,deepest=4)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Architecture_of_Stable_Diffusion_Model_Discussion")
セクション3:倫理的考慮事項
# @title Video 4: Ethical Consideration
from ipywidgets import widgets
from IPython.display import YouTubeVideo
from IPython.display import IFrame
from IPython.display import display
class PlayVideo(IFrame):
def __init__(self, id, source, page=1, width=400, height=300, **kwargs):
self.id = id
if source == 'Bilibili':
src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'
elif source == 'Osf':
src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'
super(PlayVideo, self).__init__(src, width, height, **kwargs)
def display_videos(video_ids, W=400, H=300, fs=1):
tab_contents = []
for i, video_id in enumerate(video_ids):
out = widgets.Output()
with out:
if video_ids[i][0] == 'Youtube':
video = YouTubeVideo(id=video_ids[i][1], width=W,
height=H, fs=fs, rel=0)
print(f'Video available at https://youtube.com/watch?v={video.id}')
else:
video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,
height=H, fs=fs, autoplay=False)
if video_ids[i][0] == 'Bilibili':
print(f'Video available at https://www.bilibili.com/video/{video.id}')
elif video_ids[i][0] == 'Osf':
print(f'Video available at https://osf.io/{video.id}')
display(video)
tab_contents.append(out)
return tab_contents
video_ids = [('Youtube', 'Qy8ODZ7TYZg'), ('Bilibili', 'BV1TV4y1a7Qx')]
tab_contents = display_videos(video_ids, W=854, H=480)
tabs = widgets.Tab()
tabs.children = tab_contents
for i in range(len(tab_contents)):
tabs.set_title(i, video_ids[i][0])
display(tabs)
# @title Submit your feedback
content_review(f"{feedback_prefix}_Ethical_Consideration_Video")
考えてみよう!3:ディフュージョン生成モデルから生成された画像の著作権
もし事前学習済みのディフュージョンモデルにアーティスト名をプロンプトとして与え、そのアーティストのスタイルに似た美しい画像が生成されたとします。この生成画像の著作権は誰にありますか?ディフュージョンモデルを作った会社、元のアーティスト、あなた(プロンプトを入力した人)、ランダムシードや重み、あるいは推論を実行したGPUでしょうか?
誰がクレジットを受けるべきだと思いますか?その理由は?
もし生成画像に十分な後処理を施した場合、例えばプロンプトやシードを微調整したり、画像を編集したりしたらどうでしょう?
2分間静かに考え、その後グループで議論しましょう(約10分)。
# @title Submit your feedback
content_review(f"{feedback_prefix}_Copyrights_Discussion")
まとめ
本日は以下について学びました。
- ディフュージョンモデリングの主要な応用例の一つ、すなわち自然画像のモデリング。
- 画像モデリングに適した帰納的バイアス:U-Netアーキテクチャと時間変調機構。
- 条件付きディフュージョンモデルの紹介と、Stable Diffusionのデモ。
- ディフュージョンモデルに関連する倫理的考慮事項、著作権、誤情報、公平性など。
ボーナス:MNISTの条件付きディフュージョンのトレーニング
このパートでは、数字に条件付けしたMNISTの生成モデルをトレーニングします。
ここでは基本的な条件付き変調の形態、すなわち数字の埋め込みを使い、特徴の相対的なゲインを線形に制御します。アテンション機構を学んだ後は、クロスアテンションを使ってスコアモデルを変調するなど、より良い条件付き変調の方法を考えることもできます。
条件付き変調を用いたUNetスコアモデル
class UNet_Conditional(nn.Module):
"""A time-dependent score-based model built upon U-Net architecture."""
def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256,
text_dim=256, nClass=10):
"""Initialize a time-dependent score-based network.
Args:
marginal_prob_std: A function that takes time t and gives the standard
deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
channels: The number of channels for feature maps of each resolution.
embed_dim: The dimensionality of Gaussian random feature embeddings of time.
text_dim: the embedding dimension of text / digits.
nClass: number of classes you want to model.
"""
super().__init__()
# random embedding for classes
self.cond_embed = nn.Embedding(nClass, text_dim)
# Gaussian random feature embedding layer for time
self.time_embed = nn.Sequential(
GaussianFourierProjection(embed_dim=embed_dim),
nn.Linear(embed_dim, embed_dim)
)
# Encoding layers where the resolution decreases
self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
self.t_mod1 = Dense(embed_dim, channels[0])
self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
self.t_mod2 = Dense(embed_dim, channels[1])
self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
self.y_mod2 = Dense(embed_dim, channels[1])
self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
self.t_mod3 = Dense(embed_dim, channels[2])
self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
self.y_mod3 = Dense(embed_dim, channels[2])
self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
self.t_mod4 = Dense(embed_dim, channels[3])
self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
self.y_mod4 = Dense(embed_dim, channels[3])
# Decoding layers where the resolution increases
self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
self.t_mod5 = Dense(embed_dim, channels[2])
self.y_mod5 = Dense(embed_dim, channels[2])
self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1) # + channels[2]
self.t_mod6 = Dense(embed_dim, channels[1])
self.y_mod6 = Dense(embed_dim, channels[1])
self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1) # + channels[1]
self.t_mod7 = Dense(embed_dim, channels[0])
self.y_mod7 = Dense(embed_dim, channels[0])
self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
self.tconv1 = nn.ConvTranspose2d(channels[0], 1, 3, stride=1)
# The swish activation function
self.act = nn.SiLU() # lambda x: x * torch.sigmoid(x)
self.marginal_prob_std = marginal_prob_std
for module in [self.y_mod2,self.y_mod3,self.y_mod4,
self.y_mod5,self.y_mod6,self.y_mod7]:
nn.init.normal_(module.dense.weight, mean=0, std=0.0001)
nn.init.constant_(module.dense.bias, 1.0)
def forward(self, x, t, y=None):
# Obtain the Gaussian random feature embedding for t
embed = self.act(self.time_embed(t))
y_embed = self.cond_embed(y)
# Encoding path
h1 = self.conv1(x) + self.t_mod1(embed)
## Incorporate information from t
## Group normalization
h1 = self.act(self.gnorm1(h1))
h2 = self.conv2(h1) + self.t_mod2(embed)
h2 = h2 * self.y_mod2(y_embed)
h2 = self.act(self.gnorm2(h2))
h3 = self.conv3(h2) + self.t_mod3(embed)
h3 = h3 * self.y_mod3(y_embed)
h3 = self.act(self.gnorm3(h3))
h4 = self.conv4(h3) + self.t_mod4(embed)
h4 = h4 * self.y_mod4(y_embed)
h4 = self.act(self.gnorm4(h4))
# Decoding path
h = self.tconv4(h4) + self.t_mod5(embed)
h = h * self.y_mod5(y_embed)
## Skip connection from the encoding path
h = self.act(self.tgnorm4(h))
h = self.tconv3(h + h3) + self.t_mod6(embed)
h = h * self.y_mod6(y_embed)
h = self.act(self.tgnorm3(h))
h = self.tconv2(h + h2) + self.t_mod7(embed)
h = h * self.y_mod7(y_embed)
h = self.act(self.tgnorm2(h))
h = self.tconv1(h + h1)
# Normalize output
h = h / self.marginal_prob_std(t)[:, None, None, None]
return h
条件付きディフュージョンの損失関数
def loss_fn_cond(model, x, y, marginal_prob_std, eps=1e-3):
"""The loss function for training score-based generative models.
Args:
model: A PyTorch model instance that represents a
time-dependent score-based model.
x: A mini-batch of training data.
marginal_prob_std: A function that gives the standard deviation of
the perturbation kernel.
eps: A tolerance value for numerical stability.
"""
random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
z = torch.randn_like(x)
std = marginal_prob_std(random_t)
perturbed_x = x + z * std[:, None, None, None]
score = model(perturbed_x, random_t, y=y)
loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2,
dim=(1, 2, 3)))
return loss
# @title Training conditional diffusion model
Lambda = 25 #@param {'type':'number'}
marginal_prob_std_fn = lambda t: marginal_prob_std(t, Lambda=Lambda, device=DEVICE)
diffusion_coeff_fn = lambda t: diffusion_coeff(t, Lambda=Lambda, device=DEVICE)
print("initilize new score model...")
score_model_cond = UNet_Conditional(marginal_prob_std=marginal_prob_std_fn)
score_model_cond = score_model_cond.to(DEVICE)
n_epochs = 10 # @param {'type':'integer'}
## size of a mini-batch
batch_size = 1024 # @param {'type':'integer'}
## learning rate
lr = 10e-4 # @param {'type':'number'}
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
optimizer = Adam(score_model_cond.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.99 ** epoch))
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for x, y in tqdm(data_loader):
x = x.to(DEVICE)
loss = loss_fn_cond(score_model_cond, x, y.to(DEVICE), marginal_prob_std_fn)
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item() * x.shape[0]
num_items += x.shape[0]
scheduler.step()
lr_current = scheduler.get_last_lr()[0]
print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))
# Print the averaged training loss so far.
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
# Update the checkpoint after each epoch of training.
torch.save(score_model_cond.state_dict(), 'ckpt_cond.pth')
注意: 元のn_epochsの値は100でした。
# @title Sample Conditional Diffusion
digit = 4 # @param {'type':'integer'}
sample_batch_size = 64 # @param {'type':'integer'}
num_steps = 250 # @param {'type':'integer'}
score_model_cond.eval()
## Generate samples using the specified sampler.
samples = Euler_Maruyama_sampler(
score_model_cond,
marginal_prob_std_fn,
diffusion_coeff_fn,
sample_batch_size,
num_steps=num_steps,
device=DEVICE,
y=digit*torch.ones(sample_batch_size, dtype=torch.long, device=DEVICE))
## Sample visualization.
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
plt.figure(figsize=(6, 6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()
torch.cuda.empty_cache()