Example: Score-Based Generative Model on MNIST#
This notebook is a shortened version of Yang Song’s Tutorial on Google Colab: https://colab.research.google.com/drive/120kYYBOVa1i0TD85RjlEkFjaWDxSFUx3?usp=sharing
Since many of the tutorial contents were already covered in the previous notebooks, I removed most of the explanations and only kept the code to test score-based models on the MNIST dataset.
Goals#
This is a hitchhiker’s guide to score-based generative models, a family of approaches based on estimating gradients of the data distribution. They have obtained high-quality samples comparable to GANs without requiring adversarial training, and are considered by some to be the new contender to GANs.
The contents of this notebook are mainly based on the following paper:
Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. “Score-Based Generative Modeling through Stochastic Differential Equations.” International Conference on Learning Representations, 2021.
Time-Dependent Score-Based Model#
There are no restrictions on the network architecture of time-dependent score-based models, except that their output should have the same dimensionality as the input, and they should be conditioned on time.
Several useful tips on architecture choice:
The U-Net architecture usually works well as the backbone of the score network \(s_\theta(\mathbf{x}, t)\).
We can incorporate the time information via Gaussian random features. Specifically, we first sample \(\omega \sim \mathcal{N}(\mathbf{0}, s^2\mathbf{I})\), which is subsequently fixed for the model (i.e., not learnable). For a time step \(t\), the corresponding Gaussian random feature is defined as \begin{align} [\sin(2\pi \omega t) ; \cos(2\pi \omega t)], \end{align} where \([\vec{a} ; \vec{b}]\) denotes the concatenation of the vectors \(\vec{a}\) and \(\vec{b}\). This Gaussian random feature can be used as an encoding for time step \(t\) so that the score network can condition on \(t\) by incorporating this encoding. We will see this further in the code.
We can rescale the output of the U-net by \(1/\sqrt{\mathbb{E}[\|\nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) \|_2^2]}\). This is because the optimal \(s_\theta(\mathbf{x}(t), t)\) has an \(\ell_2\)-norm close to \(\mathbb{E}[\|\nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))\|_2]\), and the rescaling helps capture the norm of the true score. Recall that the training objective contains sums of the form \begin{align*} \mathbb{E}_{\mathbf{x}(t) \sim p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))}\big[ \|s_\theta(\mathbf{x}(t), t) - \nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))\|_2^2\big]. \end{align*} Therefore, it is natural to expect that the optimal score model satisfies \(s_\theta(\mathbf{x}(t), t) \approx \nabla_{\mathbf{x}(t)} \log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))\).
Use exponential moving average (EMA) of weights when sampling. This can greatly improve sample quality, but it requires slightly longer training time and more implementation work. We do not include it in this tutorial, but highly recommend it when you employ score-based generative modeling to tackle more challenging real problems; a minimal sketch of what such an EMA helper could look like is given below.
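For illustration only, here is a minimal sketch of an EMA helper (not part of the original tutorial; the class name, method names, and decay value are my own choices). One would call update after every optimizer step during training, and copy_to before sampling to load the averaged weights into the model.
import torch

class ExponentialMovingAverage:
  """Illustrative sketch: keeps an exponential moving average of a model's parameters."""
  def __init__(self, model, decay=0.999):
    self.decay = decay
    # Shadow copies of all trainable parameters.
    self.shadow = {name: p.detach().clone()
                   for name, p in model.named_parameters() if p.requires_grad}

  def update(self, model):
    # Call once after every optimizer.step().
    with torch.no_grad():
      for name, p in model.named_parameters():
        if p.requires_grad:
          self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1. - self.decay)

  def copy_to(self, model):
    # Load the averaged weights into the model before sampling.
    with torch.no_grad():
      for name, p in model.named_parameters():
        if p.requires_grad:
          p.copy_(self.shadow[name])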
#@title Defining a time-dependent score-based model (double click to expand or collapse)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GaussianFourierProjection(nn.Module):
"""Gaussian random features for encoding time steps."""
# [\sin(2\pi \omega t) ; \cos(2\pi \omega t)]
def __init__(self, embed_dim, scale=30.):
super().__init__()
# Randomly sample weights during initialization. These weights are fixed
# during optimization and are not trainable.
self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
def forward(self, x):
x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
class Dense(nn.Module):
"""A fully connected layer that reshapes outputs to feature maps."""
def __init__(self, input_dim, output_dim):
super().__init__()
self.dense = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.dense(x)[..., None, None]
class ScoreNet(nn.Module):
"""A time-dependent score-based model built upon U-Net architecture."""
def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256):
"""Initialize a time-dependent score-based network.
Args:
marginal_prob_std: A function that takes time t and gives the standard
deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
channels: The number of channels for feature maps of each resolution.
embed_dim: The dimensionality of Gaussian random feature embeddings.
"""
super().__init__()
# Gaussian random feature embedding layer for time
self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
nn.Linear(embed_dim, embed_dim))
# Encoding layers where the resolution decreases
self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
self.dense1 = Dense(embed_dim, channels[0])
self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
self.dense2 = Dense(embed_dim, channels[1])
self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
self.dense3 = Dense(embed_dim, channels[2])
self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
self.dense4 = Dense(embed_dim, channels[3])
self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
# Decoding layers where the resolution increases
self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
self.dense5 = Dense(embed_dim, channels[2])
self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
self.dense6 = Dense(embed_dim, channels[1])
self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
self.dense7 = Dense(embed_dim, channels[0])
self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)
# The swish activation function
self.act = lambda x: x * torch.sigmoid(x)
self.marginal_prob_std = marginal_prob_std
def forward(self, x, t):
# Obtain the Gaussian random feature embedding for t
    embed = self.act(self.embed(t))  # Gaussian random feature embedding of t (Fourier features, then a linear layer and swish)
# Encoding path
h1 = self.conv1(x)
## Incorporate information from t
h1 += self.dense1(embed)
## Group normalization
h1 = self.gnorm1(h1)
h1 = self.act(h1)
h2 = self.conv2(h1)
h2 += self.dense2(embed)
h2 = self.gnorm2(h2)
h2 = self.act(h2)
h3 = self.conv3(h2)
h3 += self.dense3(embed)
h3 = self.gnorm3(h3)
h3 = self.act(h3)
h4 = self.conv4(h3)
h4 += self.dense4(embed)
h4 = self.gnorm4(h4)
h4 = self.act(h4)
# Decoding path
h = self.tconv4(h4)
## Skip connection from the encoding path
h += self.dense5(embed)
h = self.tgnorm4(h)
h = self.act(h)
h = self.tconv3(torch.cat([h, h3], dim=1))
h += self.dense6(embed)
h = self.tgnorm3(h)
h = self.act(h)
h = self.tconv2(torch.cat([h, h2], dim=1))
h += self.dense7(embed)
h = self.tgnorm2(h)
h = self.act(h)
h = self.tconv1(torch.cat([h, h1], dim=1))
# Normalize output
h = h / self.marginal_prob_std(t)[:, None, None, None]
return h
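As a quick sanity check (a sketch added here, not from the original notebook), we can confirm that the score network’s output has the same dimensionality as its input. A dummy standard-deviation function is used so the snippet runs on its own, before the SDE is set up.
# Dummy standard deviation so this check does not depend on the SDE defined later.
dummy_std = lambda t: torch.ones_like(t)
toy_model = ScoreNet(marginal_prob_std=dummy_std)
x_toy = torch.randn(4, 1, 28, 28)     # a small batch of fake 28x28 images
t_toy = torch.rand(4)                 # random time steps in (0, 1)
print(toy_model(x_toy, t_toy).shape)  # expected: torch.Size([4, 1, 28, 28])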
Training with Weighted Sum of Denoising Score Matching Objectives#
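As a brief recap (consistent with the code in the next cell and with the original tutorial), the forward SDE used to perturb the data is \(d\mathbf{x} = \sigma^t\, d\mathbf{w}\) for \(t \in [0, 1]\), whose perturbation kernel is Gaussian: \begin{align*} p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = \mathcal{N}\Big(\mathbf{x}(t);\, \mathbf{x}(0),\, \frac{\sigma^{2t} - 1}{2 \log \sigma}\, \mathbf{I}\Big). \end{align*} The function marginal_prob_std below returns the standard deviation \(\sqrt{(\sigma^{2t} - 1)/(2 \log \sigma)}\) of this kernel, and diffusion_coeff returns the diffusion coefficient \(g(t) = \sigma^t\).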
#@title Set up the SDE
import functools
device = 'mps' #@param ['cuda', 'mps', 'cpu'] {'type':'string'}
def marginal_prob_std(t, sigma):
"""Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$.
Args:
t: A vector of time steps.
sigma: The $\sigma$ in our SDE.
Returns:
The standard deviation.
"""
  t = torch.as_tensor(t, device=device)  # avoids the copy-construct warning when t is already a tensor
return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))
def diffusion_coeff(t, sigma):
"""Compute the diffusion coefficient of our SDE.
Args:
t: A vector of time steps.
sigma: The $\sigma$ in our SDE.
Returns:
The vector of diffusion coefficients.
"""
  return torch.as_tensor(sigma**t, device=device)
sigma = 25.0#@param {'type':'number'}
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
#@title Define the loss function (double click to expand or collapse)
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
"""The loss function for training score-based generative models.
Args:
model: A PyTorch model instance that represents a
time-dependent score-based model.
x: A mini-batch of training data.
marginal_prob_std: A function that gives the standard deviation of
the perturbation kernel.
eps: A tolerance value for numerical stability.
"""
random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
z = torch.randn_like(x)
std = marginal_prob_std(random_t)
perturbed_x = x + z * std[:, None, None, None]
score = model(perturbed_x, random_t)
loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1,2,3)))
return loss
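To connect loss_fn to the weighted denoising score matching objective: writing the perturbed sample as \(\mathbf{x}(t) = \mathbf{x}(0) + \sigma(t)\, \mathbf{z}\) with \(\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\), the true conditional score is \(\nabla_{\mathbf{x}(t)} \log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = -\mathbf{z}/\sigma(t)\). With the weighting \(\lambda(t) = \sigma(t)^2\), each term of the objective becomes \begin{align*} \lambda(t)\, \big\| s_\theta(\mathbf{x}(t), t) + \mathbf{z}/\sigma(t) \big\|_2^2 = \big\| \sigma(t)\, s_\theta(\mathbf{x}(t), t) + \mathbf{z} \big\|_2^2, \end{align*} which is exactly the quantity averaged over the mini-batch in the code above (std corresponds to \(\sigma(t)\)).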
#@title Training (double click to expand or collapse)
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm
from tqdm.notebook import trange, tqdm
score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)
load_model = True # load model from a checkpoint after 40 epochs - trained for approx. 40 min on an M2 Max GPU
if load_model:
ckpt = torch.load('ckpt_40_epochs.pth', map_location=device)
score_model.load_state_dict(ckpt)
n_epochs = 0 #@param {'type':'integer'}
## size of a mini-batch
batch_size = 32 #@param {'type':'integer'}
## learning rate
lr=1e-4 #@param {'type':'number'}
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
optimizer = Adam(score_model.parameters(), lr=lr)
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
avg_loss = 0.
num_items = 0
for x, y in data_loader:
x = x.to(device)
loss = loss_fn(score_model, x, marginal_prob_std_fn)
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_loss += loss.item() * x.shape[0]
num_items += x.shape[0]
# Print the averaged training loss so far.
tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
# Update the checkpoint after each epoch of training.
torch.save(score_model.state_dict(), 'ckpt.pth')
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Sampling with Numerical SDE Solvers#
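The samplers in this section numerically integrate the reverse-time SDE \(d\mathbf{x} = -\sigma^{2t}\, \nabla_{\mathbf{x}}\log p_t(\mathbf{x})\, dt + \sigma^{t}\, d\bar{\mathbf{w}}\) from \(t = 1\) down to a small \(t = \epsilon\), with the unknown score \(\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\) replaced by the trained model \(s_\theta(\mathbf{x}, t)\). The Euler-Maruyama discretization with step size \(\Delta t\) is \begin{align*} \mathbf{x}_{t - \Delta t} = \mathbf{x}_t + g(t)^2\, s_\theta(\mathbf{x}_t, t)\, \Delta t + g(t)\sqrt{\Delta t}\, \mathbf{z}_t, \qquad \mathbf{z}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \end{align*} which is exactly the update implemented in the sampler below; the last step returns the mean without adding noise.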
#@title Define the Euler-Maruyama sampler (double click to expand or collapse)
## The number of sampling steps.
num_steps = 500#@param {'type':'integer'}
def Euler_Maruyama_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
num_steps=num_steps,
device='cuda',
eps=1e-3):
"""Generate samples from score-based models with the Euler-Maruyama solver.
Args:
score_model: A PyTorch model that represents the time-dependent score-based model.
marginal_prob_std: A function that gives the standard deviation of
the perturbation kernel.
diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
    batch_size: The number of samples to generate by calling this function once.
num_steps: The number of sampling steps.
Equivalent to the number of discretized time steps.
device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
eps: The smallest time step for numerical stability.
Returns:
Samples.
"""
t = torch.ones(batch_size, device=device)
init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
* marginal_prob_std(t)[:, None, None, None]
time_steps = torch.linspace(1., eps, num_steps, device=device)
step_size = time_steps[0] - time_steps[1]
x = init_x
with torch.no_grad():
for time_step in tqdm(time_steps):
batch_time_step = torch.ones(batch_size, device=device) * time_step
g = diffusion_coeff(batch_time_step)
mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)
# Do not include any noise in the last sampling step.
return mean_x
Sampling with Predictor-Corrector Methods#
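The Predictor-Corrector sampler interleaves the Euler-Maruyama predictor above with a Langevin MCMC corrector at each time step: \begin{align*} \mathbf{x} \leftarrow \mathbf{x} + \epsilon\, s_\theta(\mathbf{x}, t) + \sqrt{2\epsilon}\, \mathbf{z}, \qquad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \end{align*} where the Langevin step size (distinct from the eps argument) is set from the signal-to-noise ratio \(r\) as \(\epsilon = 2\, \big(r\, \|\mathbf{z}\|_2 / \|s_\theta(\mathbf{x}, t)\|_2\big)^2\). In the code below, one corrector step is applied per predictor step, with \(\|\mathbf{z}\|_2\) approximated by \(\sqrt{D}\) for data dimension \(D\) and the score norm averaged over the mini-batch.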
#@title Define the Predictor-Corrector sampler (double click to expand or collapse)
signal_to_noise_ratio = 0.16 #@param {'type':'number'}
## The number of sampling steps.
num_steps = 500#@param {'type':'integer'}
def pc_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
num_steps=num_steps,
snr=signal_to_noise_ratio,
device='cuda',
eps=1e-3):
"""Generate samples from score-based models with Predictor-Corrector method.
Args:
score_model: A PyTorch model that represents the time-dependent score-based model.
marginal_prob_std: A function that gives the standard deviation
of the perturbation kernel.
diffusion_coeff: A function that gives the diffusion coefficient
of the SDE.
    batch_size: The number of samples to generate by calling this function once.
num_steps: The number of sampling steps.
Equivalent to the number of discretized time steps.
device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
eps: The smallest time step for numerical stability.
Returns:
Samples.
"""
t = torch.ones(batch_size, device=device)
init_x = torch.randn(batch_size, 1, 28, 28, device=device) * marginal_prob_std(t)[:, None, None, None]
time_steps = np.linspace(1., eps, num_steps)
step_size = time_steps[0] - time_steps[1]
x = init_x
with torch.no_grad():
for time_step in tqdm(time_steps):
batch_time_step = torch.ones(batch_size, device=device) * time_step
# Corrector step (Langevin MCMC)
grad = score_model(x, batch_time_step)
grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
noise_norm = np.sqrt(np.prod(x.shape[1:]))
langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
x = x + langevin_step_size * grad + torch.sqrt(2 * langevin_step_size) * torch.randn_like(x)
# Predictor step (Euler-Maruyama)
g = diffusion_coeff(batch_time_step)
x_mean = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step) * step_size
x = x_mean + torch.sqrt(g**2 * step_size)[:, None, None, None] * torch.randn_like(x)
# The last step does not include any noise
return x_mean
Sampling with Numerical ODE Solvers#
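Every SDE has a corresponding probability flow ODE that shares the same marginal densities \(p_t(\mathbf{x})\). For the SDE used here (zero drift), it reads \begin{align*} \frac{d\mathbf{x}}{dt} = -\frac{1}{2}\, g(t)^2\, s_\theta(\mathbf{x}, t), \end{align*} which is the right-hand side evaluated by ode_func below. Integrating it from \(t = 1\) to \(t = \epsilon\) with a black-box solver (scipy.integrate.solve_ivp with RK45) turns sampling into a deterministic map from the initial noise \(\mathbf{x}(1)\) to a sample \(\mathbf{x}(\epsilon)\).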
#@title Define the ODE sampler (double click to expand or collapse)
from scipy import integrate
## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5 #@param {'type': 'number'}
def ode_sampler(score_model,
marginal_prob_std,
diffusion_coeff,
batch_size=64,
atol=error_tolerance,
rtol=error_tolerance,
device='cuda',
z=None,
eps=1e-3):
"""Generate samples from score-based models with black-box ODE solvers.
Args:
score_model: A PyTorch model that represents the time-dependent score-based model.
marginal_prob_std: A function that returns the standard deviation
of the perturbation kernel.
diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
    batch_size: The number of samples to generate by calling this function once.
atol: Tolerance of absolute errors.
rtol: Tolerance of relative errors.
device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
z: The latent code that governs the final sample. If None, we start from p_1;
otherwise, we start from the given z.
eps: The smallest time step for numerical stability.
"""
t = torch.ones(batch_size, device=device)
# Create the latent code
if z is None:
init_x = torch.randn(batch_size, 1, 28, 28, device=device) \
* marginal_prob_std(t)[:, None, None, None]
else:
init_x = z
shape = init_x.shape
def score_eval_wrapper(sample, time_steps):
"""A wrapper of the score-based model for use by the ODE solver."""
sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
with torch.no_grad():
score = score_model(sample, time_steps)
return score.cpu().numpy().reshape((-1,)).astype(np.float64)
def ode_func(t, x):
"""The ODE function for use by the ODE solver."""
time_steps = np.ones((shape[0],)) * t
g = diffusion_coeff(torch.tensor(t, dtype=torch.float32)).cpu().numpy()
return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)
# Run the black-box ODE solver.
res = integrate.solve_ivp(ode_func, (1., eps), init_x.reshape(-1).cpu().numpy(), rtol=rtol, atol=atol, method='RK45')
print(f"Number of function evaluations: {res.nfev}")
x = torch.tensor(res.y[:, -1], dtype=torch.float32, device=device).reshape(shape)
return x
#@title Sampling (double click to expand or collapse)
from torchvision.utils import make_grid
## Load the pre-trained checkpoint from disk.
device = 'mps' #@param ['cuda', 'mps', 'cpu'] {'type':'string'}
ckpt = torch.load('ckpt_40_epochs.pth', map_location=device)
score_model.load_state_dict(ckpt)
sample_batch_size = 64 #@param {'type':'integer'}
sampler = pc_sampler #@param ['Euler_Maruyama_sampler', 'pc_sampler', 'ode_sampler'] {'type': 'raw'}
## Generate samples using the specified sampler.
samples = sampler(score_model,
marginal_prob_std_fn,
diffusion_coeff_fn,
sample_batch_size,
device=device)
## Sample visualization.
samples = samples.clamp(0.0, 1.0)
%matplotlib inline
import matplotlib.pyplot as plt
sample_grid = make_grid(samples, nrow=int(np.sqrt(sample_batch_size)))
plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
plt.show()

Further Resources#
If you’re interested in learning more about score-based generative models, the following papers would be a good start:
Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. “Score-Based Generative Modeling through Stochastic Differential Equations.” International Conference on Learning Representations, 2021.
Jonathan Ho, Ajay Jain, and Pieter Abbeel. “Denoising Diffusion Probabilistic Models.” Advances in Neural Information Processing Systems, 2020.
Yang Song and Stefano Ermon. “Improved Techniques for Training Score-Based Generative Models.” Advances in Neural Information Processing Systems, 2020.
Yang Song and Stefano Ermon. “Generative Modeling by Estimating Gradients of the Data Distribution.” Advances in Neural Information Processing Systems, 2019.