Illustration of Diffusion Process#

Note: this notebook has been lightly modified from its original form in the DiffusionFastForward repository (mikonvergence/DiffusionFastForward).

In this notebook, the intricacies of a denoising diffusion framework are illustrated with the aid of simple snippets.

First, let’s import an image to use for the examples.

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import imageio.v2 as imageio  # v2 API keeps the legacy imread behavior

mpl.rcParams['figure.figsize'] = (12, 8)

# img = torch.FloatTensor(imageio.imread('imgs/hills_2.png')/255)
img = torch.FloatTensor(imageio.imread('imgs/Klettersteig.jpg')/255)
plt.imshow(img)
_images/0bd54d02d56181d6571ea620bb3859c43d1781125c633f7b45dd6548799f35c4.png

Data Preprocessing#

Most diffusion models assume that images are scaled to the [-1,+1] range (which tends to simplify many equations). This tutorial follows the same approach, so we need to define input and output transformation functions input_T() and output_T().

Also, let’s define our own show() wrapper function that displays the image with automatic output transformation!

def input_T(input):
    # [0,1] -> [-1,+1]
    return 2*input-1
    
def output_T(input):
    # [-1,+1] -> [0,1]
    return (input+1)/2

def show(input):
    plt.imshow(output_T(input).clip(0,1))
    
img_=input_T(img)
show(img_)
_images/9900f498c3ef4292df5a3badf9993473b99bea9a360011f8437d1c6f0426b172.png
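As a quick sanity check, the two transformations should be exact inverses of each other. The sketch below uses scalar re-implementations in plain Python and a few hypothetical pixel intensities (the values are assumptions chosen for illustration, not taken from the image):

```python
def input_T(x):
    # [0, 1] -> [-1, +1]
    return 2 * x - 1


def output_T(x):
    # [-1, +1] -> [0, 1]
    return (x + 1) / 2


# Hypothetical pixel intensities in [0, 1]
pixels = [0.0, 0.25, 0.5, 1.0]

scaled = [input_T(p) for p in pixels]       # values in [-1, +1]
restored = [output_T(s) for s in scaled]    # back in [0, 1]

print("scaled:  ", scaled)
print("restored:", restored)
```

The same round trip holds element-wise for the tensor versions defined above, since both functions are affine.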

Defining a schedule#

The diffusion process is built based on a variance schedule, which determines the levels of added noise at each step of the process. To that end, our schedule is defined below, with the following quantities:

  • betas: \(\beta_t \in [0, 1]\)

  • alphas: \(\alpha_t=1-\beta_t\)

  • alphas_sqrt: \(\sqrt{\alpha_t}\)

  • alphas_cumprod: \(\bar{\alpha}_t=\prod_{i=0}^{t}\alpha_i\) (0-based index, as in the code)

  • alphas_cumprod_sqrt: \(\sqrt{\bar{\alpha}_t}\)

num_timesteps=1000
betas=torch.linspace(1e-4,2e-2,num_timesteps)   # from 0.0001 to 0.02 in 1000 steps

alphas=1-betas                                  # from 0.9999 to 0.98 in 1000 steps
alphas_sqrt=alphas.sqrt()
alphas_cumprod=torch.cumprod(alphas,0)
alphas_cumprod_sqrt=alphas_cumprod.sqrt()
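To get a feel for this schedule, it can help to print how much of the original signal survives at a few timesteps. The sketch below rebuilds the same linear schedule in plain Python (a list comprehension standing in for torch.linspace, so it runs without PyTorch) and reports the signal scale \(\sqrt{\bar{\alpha}_t}\) and the noise scale \(\sqrt{1-\bar{\alpha}_t}\). The chosen timesteps are arbitrary.

```python
import math

num_timesteps = 1000
beta_start, beta_end = 1e-4, 2e-2

# Linear schedule, equivalent in spirit to torch.linspace(1e-4, 2e-2, 1000)
betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
         for i in range(num_timesteps)]

# Running product of alpha_i = 1 - beta_i (0-based, like torch.cumprod)
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)

for t in (0, 249, 499, 749, 999):
    signal = math.sqrt(alphas_cumprod[t])
    noise = math.sqrt(1.0 - alphas_cumprod[t])
    print(f"t={t:4d}  sqrt(alpha_bar)={signal:.4f}  sqrt(1-alpha_bar)={noise:.4f}")
```

The signal scale shrinks monotonically while the noise scale approaches 1, which is why the final state is close to pure standard Gaussian noise.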

Forward Process#

The forward process \(q\) determines how subsequent steps in the diffusion are derived (gradual distortion of the original sample \(x_0\)).

📃 First, let’s bring up the key equations describing this process…

Basic format of the forward step:

\[ q(x_t|x_{t-1}) := \mathcal{N}(x_t; \sqrt{1 - \beta_t}x_{t-1}, \beta_t I) \tag{1} \]

For a complete trajectory from \(x_0\) to \(x_T\), the joint distribution of \(x_{1:T}\) is the following product of conditional distributions:

\[ q(\mathbf{x}_{1:T} \vert \mathbf{x}_0) = \prod^T_{t=1} q(\mathbf{x}_t \vert \mathbf{x}_{t-1}) \]

For infinite steps \(T \rightarrow \infty\) the input data would be transformed into a variable from an isotropic Gaussian distribution \(x_T \sim N(0,I)\).


To step directly from \(x_0\) to \(x_t\) we can use the reparametrization trick (with \(\bar{\alpha}_t=\prod^t_{i=1}\alpha_i\)):

\[ q(x_t|x_0) = \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_0, (1 - \bar{\alpha}_t)I) \tag{2} \]
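One way to convince yourself of (2) is to track the mean coefficient and the variance of \(x_t\) while applying (1) repeatedly. Each step multiplies the mean by \(\sqrt{\alpha_t}\) and updates the variance as \(\alpha_t \cdot \text{var} + \beta_t\). The sketch below (plain Python, same linear schedule rebuilt as a list) checks numerically that this recursion matches the closed form at every step. The variable names are illustrative only.

```python
import math

num_timesteps = 1000
beta_start, beta_end = 1e-4, 2e-2
betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
         for i in range(num_timesteps)]

mean_coef = 1.0   # coefficient of x_0 in the mean after the steps so far
variance = 0.0    # variance of x_t given x_0 after the steps so far
alpha_bar = 1.0   # running product of alphas (closed form)
max_gap = 0.0     # largest mismatch between recursion and closed form

for beta in betas:
    alpha = 1.0 - beta
    # One application of q(x_t | x_{t-1}) from (1)
    mean_coef *= math.sqrt(alpha)
    variance = alpha * variance + beta
    # Closed form from (2)
    alpha_bar *= alpha
    max_gap = max(max_gap,
                  abs(mean_coef - math.sqrt(alpha_bar)),
                  abs(variance - (1.0 - alpha_bar)))

print("largest mismatch:", max_gap)
```

The mismatch stays at floating-point rounding level, so chaining single steps and jumping directly describe the same distribution.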

Forward Step#

Let’s define two functions: forward_step(), which samples from \(q(x_t|x_{t-1})\), and forward_jump(), which samples from \(q(x_t|x_0)\).

def forward_step(t, condition_img, return_noise=False):
    """
        forward step: t-1 -> t, i.e. a sample from q(x_t | x_{t-1})
    """    
    assert t >= 0

    mean=alphas_sqrt[t]*condition_img    
    std=betas[t].sqrt()
      
    # sampling from N(mean, std^2 I); the noise is shaped like the conditioning image
    noise=torch.randn_like(condition_img)
    if return_noise:
        return mean+std*noise, noise
    return mean+std*noise
    
def forward_jump(t, condition_img, condition_idx=0, return_noise=False):
    """
        forward jump: 0 -> t, i.e. a sample from q(x_t | x_0)
        (condition_idx is currently unused)
    """   
    assert t >= 0
    
    mean=alphas_cumprod_sqrt[t]*condition_img
    std=(1-alphas_cumprod[t]).sqrt()
      
    # sampling from N(mean, std^2 I); the noise is shaped like the conditioning image
    noise=torch.randn_like(condition_img)
    if return_noise:
        return mean+std*noise, noise
    return mean+std*noise

N=5 # number of computed states between x_0 and x_T
M=4 # number of samples taken from each distribution
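The two functions should agree statistically: chaining single steps up to index t and jumping there directly should yield samples with the same mean and spread. Below is a minimal Monte Carlo sketch on a single scalar "pixel", written in plain Python so it runs standalone. The scalar functions mirror forward_step() and forward_jump() above, while x_0, t_target, num_samples and the random seed are arbitrary assumptions.

```python
import math
import random
import statistics

random.seed(0)

num_timesteps = 1000
beta_start, beta_end = 1e-4, 2e-2
betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
         for i in range(num_timesteps)]
alphas = [1.0 - b for b in betas]
alphas_cumprod = []
running = 1.0
for a in alphas:
    running *= a
    alphas_cumprod.append(running)


def forward_step(t, x_prev):
    # Scalar version of q(x_t | x_{t-1})
    return math.sqrt(alphas[t]) * x_prev + math.sqrt(betas[t]) * random.gauss(0.0, 1.0)


def forward_jump(t, x_0):
    # Scalar version of q(x_t | x_0)
    return (math.sqrt(alphas_cumprod[t]) * x_0
            + math.sqrt(1.0 - alphas_cumprod[t]) * random.gauss(0.0, 1.0))


x_0 = 0.8          # hypothetical pixel value in [-1, +1]
t_target = 200     # 0-based index of the target step
num_samples = 2000

chained = []
for _ in range(num_samples):
    x = x_0
    for t in range(t_target + 1):   # apply steps 0..t_target one by one
        x = forward_step(t, x)
    chained.append(x)

jumped = [forward_jump(t_target, x_0) for _ in range(num_samples)]

print("chained: mean={:.3f} std={:.3f}".format(statistics.mean(chained), statistics.stdev(chained)))
print("jumped:  mean={:.3f} std={:.3f}".format(statistics.mean(jumped), statistics.stdev(jumped)))
```

Both sets of samples should cluster around \(\sqrt{\bar{\alpha}_t}x_0\) with standard deviation close to \(\sqrt{1-\bar{\alpha}_t}\), up to Monte Carlo error.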

In the first example, we condition on the clean sample \(x_0\) (the state at t==0) and derive a sample \(x_t\) directly with forward_jump()!

The first column shows the mean image for a given stage of the diffusion, and the subsequent columns to the right show several samples taken from the same distribution (they are different if you look closely!).

plt.figure(figsize=(12,8))
for idx in range(N):
    t_step=int(idx*(num_timesteps/N))
    
    plt.subplot(N,1+M,1+(M+1)*idx)
    show(alphas_cumprod_sqrt[t_step]*img_)
    plt.title(r'$\mu_t=\sqrt{\bar{\alpha}_t}x_0$') if idx==0 else None
    plt.ylabel("t: {:.2f}".format(t_step/num_timesteps))
    plt.xticks([])
    plt.yticks([])
    
    for sample in range(M):
        x_t=forward_jump(t_step,img_)
        
        plt.subplot(N,1+M,2+(1+M)*idx+sample)
        show(x_t)        
        plt.axis('off')
        
plt.tight_layout()
_images/8849df6a88f2f225f9eeec8a5c61f8097d5ca3f26cde84e344c9fff0cf70f0eb.png

Alternatively, we can test the process of going from \(x_{t-1}\) to \(x_t\), which is a single step in the diffusion process. For that we can use the forward_step function.

Note that the mean \(\mu_t\) is now a bit different (first column) since it is conditioned on a specific sample of \(x_{t-1}\)!

plt.figure(figsize=(12,8))
for idx in range(N):
    t_step=int(idx*(num_timesteps/N))
    prev_img=forward_jump(max([0,t_step-1]),img_) # directly go to prev state
    
    plt.subplot(N,1+M,1+(M+1)*idx)
    show(alphas_sqrt[t_step]*prev_img)
    plt.title(r'$\mu_t=\sqrt{1-\beta_t}x_{t-1}$') if idx==0 else None
    plt.ylabel("t: {:.2f}".format(t_step/num_timesteps))
    plt.xticks([])
    plt.yticks([])
    
    for sample in range(M):
        plt.subplot(N,1+M,2+(1+M)*idx+sample)
        x_t=forward_step(t_step,prev_img)
        show(x_t)        
        plt.axis('off')
plt.tight_layout()
_images/4691b93895e718aa6d2c53b9b79d02d655cdc99fe91e4f59137a3fd7d925bc1b.png

Reverse Process#

The purpose of the reverse process \(p\) is to approximate the previous step \(x_{t-1}\) in the diffusion chain based on a sample \(x_t\). In practice, this approximation \(p(x_{t-1}|x_t)\) must be done without the knowledge of \(x_0\).

A parametrizable prediction model with parameters \(\theta\) is used to estimate \(p_\theta(x_{t-1}|x_t)\).

The reverse process will also be (approximately) Gaussian if the diffusion steps are small enough:

\[ p_\theta({x}_{0:T}) = p({x}_T) \prod^T_{t=1} p_\theta({x}_{t-1} \vert {x}_t) \quad p_\theta({x}_{t-1} \vert {x}_t) := \mathcal{N}({x}_{t-1}; \boldsymbol{\mu}_\theta({x}_t, t), \boldsymbol{\Sigma}_\theta({x}_t, t)) \tag{3} \]

In many works, it is assumed that the variance of this distribution should not depend strongly on \(x_0\) or \(x_t\), but rather on the stage of the diffusion process \(t\). This can be observed in the true distribution \(q(x_{t-1}|x_t, x_0)\), where the variance of the distribution equals \(\tilde{\beta}_t\).
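The expression for \(\tilde{\beta}_t\) is not written out above. The sketch below assumes the form given in the DDPM paper (Ho et al., 2020), \(\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\), and compares it with \(\beta_t\) under the linear schedule used in this notebook. It is plain Python, and the helper name posterior_variance is hypothetical.

```python
num_timesteps = 1000
beta_start, beta_end = 1e-4, 2e-2
betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
         for i in range(num_timesteps)]

alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)


def posterior_variance(t):
    """Assumed DDPM posterior variance beta_tilde_t, for a 0-based index t >= 1."""
    return (1.0 - alphas_cumprod[t - 1]) / (1.0 - alphas_cumprod[t]) * betas[t]


# Ratio beta_tilde_t / beta_t at a few arbitrary timesteps
ratios = {t: posterior_variance(t) / betas[t] for t in (1, 10, 100, 500, 999)}
for t, ratio in ratios.items():
    print(f"t={t:4d}  beta_tilde/beta={ratio:.6f}")
```

Under this assumption the variance depends only on the schedule and on \(t\), not on the image content, and it approaches \(\beta_t\) for larger \(t\).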

Parameterizing \(\mu_\theta\)#

There are at least 3 ways of parameterizing the mean of the reverse step distribution \(p_\theta(x_{t-1}|x_t)\):

  1. Directly (a neural network will estimate \(\mu_\theta\))

\[ \mu_\theta = \mu_\theta(x_t,t) \]

  2. Via \(x_0\) (a neural network will estimate \(x_0\))

\[ \tilde{\mu}_\theta = \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1-\bar{\alpha}_t}x_{0, \theta}(x_t,t) + \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}x_t \tag{4} \]

  3. Via noise \(\epsilon\) subtraction from \(x_t\) (a neural network will estimate \(\epsilon\))

\[ x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}(x_t-\sqrt{1-\bar{\alpha}_t} \epsilon_\theta(x_t,t)) \tag{5} \]

Approach 3, in which the network approximates the Gaussian noise \(\epsilon_\theta\), is the most widely used.
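Approaches 2 and 3 are closely linked: substituting (5) into (4) gives a mean written purely in terms of \(x_t\) and \(\epsilon\). The sketch below checks this numerically on one scalar value in plain Python. The closed form used for comparison, \(\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\right)\), is the expression commonly quoted from the DDPM paper and is treated here as an assumption to be verified. The values of x_0, t and the seed are arbitrary.

```python
import math
import random

random.seed(0)

num_timesteps = 1000
beta_start, beta_end = 1e-4, 2e-2
betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
         for i in range(num_timesteps)]
alphas = [1.0 - b for b in betas]
alphas_cumprod = []
running = 1.0
for a in alphas:
    running *= a
    alphas_cumprod.append(running)

t = 50
x_0 = 0.3                        # hypothetical clean pixel value
eps = random.gauss(0.0, 1.0)     # noise used in the forward jump
a_bar_t, a_bar_prev = alphas_cumprod[t], alphas_cumprod[t - 1]

# Forward jump, as in (2)
x_t = math.sqrt(a_bar_t) * x_0 + math.sqrt(1.0 - a_bar_t) * eps

# Approach 3: recover x_0 from the (here: exactly known) noise, as in (5)
x_0_pred = (x_t - math.sqrt(1.0 - a_bar_t) * eps) / math.sqrt(a_bar_t)

# Approach 2: plug the x_0 estimate into (4)
mean_via_x0 = (math.sqrt(a_bar_prev) * betas[t] / (1.0 - a_bar_t) * x_0_pred
               + math.sqrt(alphas[t]) * (1.0 - a_bar_prev) / (1.0 - a_bar_t) * x_t)

# Assumed epsilon-only form of the same mean
mean_via_eps = (x_t - betas[t] / math.sqrt(1.0 - a_bar_t) * eps) / math.sqrt(alphas[t])

print("x_0 recovered:", x_0_pred)
print("mean via x_0: ", mean_via_x0)
print("mean via eps: ", mean_via_eps)
```

The two means agree up to rounding, so predicting \(\epsilon\) carries the same information as predicting \(x_0\) or the mean itself.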

For example, \(\epsilon\) can be estimated with a U-Net:

Figure: U-Net architecture with ResNet blocks (from: https://cvpr2022-tutorial-diffusion-models.github.io/)

Let’s look at what an example \(\epsilon\) might look like:

t_step=50

x_t,noise=forward_jump(t_step,img_,return_noise=True)

plt.subplot(1,3,1)
show(img_)
plt.title(r'$x_0$')
plt.axis('off')
plt.subplot(1,3,2)
show(x_t)
plt.title(r'$x_t$')
plt.axis('off')
plt.subplot(1,3,3)
show(noise)
plt.title(r'$\epsilon$')
plt.axis('off')
_images/87f46699bb06218136df3ba55190cfd34d8780a1d763eb8df718dfd50bd216db.png

If \(\epsilon\) is predicted correctly, we can use equation (5) to recover \(x_0\):

x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*noise)/(alphas_cumprod_sqrt[t_step])

plt.subplot(1,3,1)
show(x_t)
plt.title(r'$x_t$ ($\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))
plt.axis('off')
plt.subplot(1,3,2)
show(x_0_pred)
plt.title(r'$x_0$ prediction ($\ell_1$: {:.3f})'.format(F.l1_loss(x_0_pred,img_)))
plt.axis('off') 
plt.subplot(1,3,3)
show(img_)
plt.title('$x_0$')
plt.axis('off')
_images/5bd66851840edbd7e880af385bd5643200cc362d9c2140534a089ea8562fcaca.png

Approximation (or knowledge) of \(x_0\) allows us to approximate the mean of the step \(t-1\), using (4).

# estimate mean using (4): weighted sum of the x_0 estimate and x_t
coef_x0=alphas_cumprod_sqrt[t_step-1]*betas[t_step]/(1-alphas_cumprod[t_step])
coef_xt=alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1])/(1-alphas_cumprod[t_step])
mean_pred=coef_x0*x_0_pred+coef_xt*x_t

# let's compare it to ground truth mean of the previous step (requires knowledge of x_0)
mean_gt=alphas_cumprod_sqrt[t_step-1]*img_

The reverse-process mean estimate \(\tilde{\mu}_\theta\) in (4) is effectively a linear interpolation between the noisy \(x_t\) and \(x_0\). Because the additive noise in \(x_t\) is still present, it is expected to have a higher error than the mean computed using the forward process, which simply scales the clean sample by a scalar.
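To make the interpolation concrete, the sketch below computes the two weights from (4) at t_step=50 in plain Python. It also checks an identity that follows from (2) and (4): the weight on \(x_0\) plus \(\sqrt{\bar{\alpha}_t}\) times the weight on \(x_t\) equals \(\sqrt{\bar{\alpha}_{t-1}}\). In other words, the estimated mean equals the forward-process mean plus a scaled copy of the noise. The names coef_x0, coef_xt and residual_noise_scale are illustrative only.

```python
import math

num_timesteps = 1000
beta_start, beta_end = 1e-4, 2e-2
betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
         for i in range(num_timesteps)]
alphas = [1.0 - b for b in betas]
alphas_cumprod = []
running = 1.0
for a in alphas:
    running *= a
    alphas_cumprod.append(running)

t_step = 50
a_bar_t, a_bar_prev = alphas_cumprod[t_step], alphas_cumprod[t_step - 1]

# Weights of x_0 and x_t in the mean from (4)
coef_x0 = math.sqrt(a_bar_prev) * betas[t_step] / (1.0 - a_bar_t)
coef_xt = math.sqrt(alphas[t_step]) * (1.0 - a_bar_prev) / (1.0 - a_bar_t)

# With x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, the mean becomes
#   sqrt(a_bar_prev) * x_0 + residual_noise_scale * eps
residual_noise_scale = coef_xt * math.sqrt(1.0 - a_bar_t)

print(f"weight on x_0: {coef_x0:.4f}")
print(f"weight on x_t: {coef_xt:.4f}")
print(f"noise left in the mean (per unit eps): {residual_noise_scale:.4f}")
```

At this early timestep almost all of the weight sits on \(x_t\), so most of the noise in \(x_t\) carries over into the estimated mean.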

plt.subplot(1,3,1)
show(x_t)
plt.title(r'$x_t$   ($\ell_1$: {:.3f})'.format(F.l1_loss(x_t,img_)))
plt.subplot(1,3,2)
show(mean_pred)
plt.title(r'$\tilde{\mu}_{t-1}$' + r'  ($\ell_1$: {:.3f})'.format(F.l1_loss(mean_pred,img_)))
plt.subplot(1,3,3)
show(mean_gt)
plt.title(r'$\mu_{t-1}$' + r'  ($\ell_1$: {:.3f})'.format(F.l1_loss(mean_gt,img_)))
_images/cb52a1ed1f92d71a204b29858733ebaff02a40b979e281d747e8442cc43dc470.png

Once we have our mean_pred (\(\tilde{\mu}_\theta\)), we can define the distribution for the previous step. For simplicity, the variance is set to \(\beta_t\):

\[\tilde{\beta}_t=\beta_t \tag{6}\]
\[ p_\theta(x_{t-1}|x_t) := \mathcal{N}(x_{t-1};\tilde{\mu}_\theta(x_t,t),\tilde{\beta}_t I) \tag{7}\]

Important: the experiment below should be treated as a simulation. In practice, the network must predict either \(\epsilon_\theta\) or \(x^{\theta}_0\) or \(\tilde{\mu}_\theta\). Here, the value of \(\epsilon\) is simply substituted with the known noise from the forward process rather than predicted by a network.

def reverse_step(epsilon, x_t, t_step, return_noise=False):
    
    # estimate x_0 based on epsilon, using (5)
    x_0_pred=(x_t-(1-alphas_cumprod[t_step]).sqrt()*epsilon)/(alphas_cumprod_sqrt[t_step])
    if t_step==0:
        # last step: return the x_0 estimate without adding noise
        sample=x_0_pred
        noise=torch.zeros_like(x_0_pred)
    else:
        # estimate mean, using (4)
        mean_pred=x_0_pred*(alphas_cumprod_sqrt[t_step-1]*betas[t_step])/(1-alphas_cumprod[t_step]) + x_t*(alphas_sqrt[t_step]*(1-alphas_cumprod[t_step-1]))/(1-alphas_cumprod[t_step])

        # standard deviation of the step: square root of the variance beta_t, see (6)
        std_pred=betas[t_step].sqrt()

        sample=mean_pred+std_pred*torch.randn_like(x_t)
        # this noise is only computed for simulation purposes (since x_0_pred is not known normally)
        noise=(sample-x_0_pred*alphas_cumprod_sqrt[t_step-1])/(1-alphas_cumprod[t_step-1]).sqrt()
    if return_noise:
        return sample, noise
    else:
        return sample

x_t,noise=forward_jump(1000-1,img_,return_noise=True)

state_imgs=[x_t.numpy()]
for t_step in reversed(range(1000)):
    x_t,noise=reverse_step(noise,x_t,t_step,return_noise=True)
    
    if t_step % 200 == 0:
        state_imgs.append(x_t.numpy())
plt.figure()
for idx,state_img in enumerate(state_imgs):
    plt.subplot(1,len(state_imgs),idx+1)
    show(state_img.clip(-1,1))
    plt.axis('off')
    
plt.tight_layout()
_images/c0150486569bebb24ef395c0cd3a551c909cdb5f356043401b0ba09ec1033931.png
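The same simulation can be traced on a single scalar "pixel" to see why the final state matches the original image so closely. Because the noise passed to each step is recomputed from the exact \(x_0\) estimate, every step recovers \(x_0\) exactly, which a trained network would only approximate. The sketch below is a plain-Python analogue of reverse_step() above. The starting value x_0 and the seed are arbitrary assumptions.

```python
import math
import random

random.seed(0)

num_timesteps = 1000
beta_start, beta_end = 1e-4, 2e-2
betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
         for i in range(num_timesteps)]
alphas = [1.0 - b for b in betas]
alphas_cumprod = []
running = 1.0
for a in alphas:
    running *= a
    alphas_cumprod.append(running)


def reverse_step(epsilon, x_t, t):
    """Scalar analogue of the simulated reverse step: returns (sample, noise)."""
    a_bar_t = alphas_cumprod[t]
    # Estimate x_0 from the supplied noise, as in (5)
    x_0_pred = (x_t - math.sqrt(1.0 - a_bar_t) * epsilon) / math.sqrt(a_bar_t)
    if t == 0:
        return x_0_pred, 0.0
    a_bar_prev = alphas_cumprod[t - 1]
    # Mean of the reverse step, as in (4)
    mean = (math.sqrt(a_bar_prev) * betas[t] / (1.0 - a_bar_t) * x_0_pred
            + math.sqrt(alphas[t]) * (1.0 - a_bar_prev) / (1.0 - a_bar_t) * x_t)
    # Sample with variance beta_t, as in (6) and (7)
    sample = mean + math.sqrt(betas[t]) * random.gauss(0.0, 1.0)
    # Noise that relates the new sample to x_0_pred (simulation only)
    noise = (sample - math.sqrt(a_bar_prev) * x_0_pred) / math.sqrt(1.0 - a_bar_prev)
    return sample, noise


x_0 = 0.3                          # hypothetical clean pixel value
noise = random.gauss(0.0, 1.0)
last = num_timesteps - 1
x_t = math.sqrt(alphas_cumprod[last]) * x_0 + math.sqrt(1.0 - alphas_cumprod[last]) * noise

for t in reversed(range(num_timesteps)):
    x_t, noise = reverse_step(noise, x_t, t)

print("recovered:", x_t, " original:", x_0)
```

With a learned \(\epsilon_\theta\) in place of the recomputed noise, the recovery would no longer be exact, and the quality of the final sample would depend on the network.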