你的浏览器不支持canvas

Welcome to my home

DDPM(去噪扩散概率模型)

Date: Author: cxb

本文系统讲解 DDPM(去噪扩散概率模型) 的核心原理与实现细节,揭示其如何通过渐进加噪与去噪学习生成高质量图像。

🧠 DDPM

[TOC]

🔄 一、核心思想:从噪声到图像

扩散模型是一种生成模型,目标是从纯高斯噪声一步步生成真实图像

它包含两个阶段:

阶段 方向 名称 做了什么
正向 $x_0 \to x_T$ Forward / Diffusion 不断加噪,让图像变成随机噪声
反向 $x_T \to x_0$ Reverse / Denoising 学习去噪,还原出原始图像

📦 二、正向过程:加噪

我们从原图 $x_0$ 出发,在每个时刻 $t$ 加入一点噪声,最终得到 $x_T$,一个近似高斯噪声的图像: \(x_t = \sqrt{\bar{\alpha}_t} \cdot x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)\)

  • $\overline{\alpha_t} = \prod_{i=1}^{t} \alpha_i$, 其中 $\alpha_{i}=1-\beta_{i}$ ( $β_i$ 是第 $i$ 步加的噪声强度), 从第 1 步到第 t 步累计保留的图像信息量
  • $t = 0 \to T$,图像越来越模糊;
  • 这个过程是可闭式计算的,无需一步步执行,一次公式就能生成 $x_t$​ ✅

🧠 三、反向过程:学习去噪

🎯 目标

从 $x_T \sim \mathcal{N}(0, I)$ 出发,逐步去噪还原出 $x_0$

🤖 学什么?

我们训练一个神经网络 $\epsilon_\theta(x_t, t)$ 来预测=预测加噪时用的 $\epsilon$。


🔁 Trick:从 $x_t$ 推出 $x_0$,再推出 $x_{t-1}$

从正向公式: \(x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon\)

可以反解出: \(x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}} \left( x_t - \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_\theta(x_t, t) \right)\)

✅ 有了 $\epsilon_\theta$,就能估计 $x_0$,接着继续计算 $x_{t-1}$。


❓为什么不直接从 $x_t$ 算 $x_{t-1}$?

虽然正向有: \(x_t = f(x_{t-1}) + \text{noise}\) 但反向是概率分布,不是函数。因为:

  • 多个 $x_{t-1}$ 可能加上不同噪声后变成同一个 $x_t$
  • 所以 $q(x_{t-1}|x_t)$​ 是个复杂分布,没法显式表示

我们没有办法得到 $q(x_{t-1})$或 $q(x_t)$的明确表达式,因为它们涉及从 $x_0$ 积分过来的所有路径: \(q(x_t) = \int q(x_t \mid x_{t-1}) q(x_{t-1}) dx_{t-1}\) 但 $q(x_{t-1})$并不是一个简单的分布!因为它本身是从一系列有噪声扰动的步骤中一步步卷积出来的复杂分布,然后 $q(x_{t-2})$还要再由 $q(x_{t-3})$ 推来……最终都依赖于 $q(x_0)$,也就是原始数据分布。但!💥 我们根本不知道 $q(x_0)$ 是什么!

因此,我们只能通过$x_0$估计它的均值和方差。


🔁 反向采样公式估计,引入$x_0$,用可解的 $q(x_{t-1} \mid x_t, x_0)$ 来估计$x_{t-1}$

利用贝叶斯公式: \(p\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_{t}, x_{0}\right)=\frac{p\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{t-1}\right) p\left(\boldsymbol{x}_{t-1} \mid x_{0}\right)}{p\left(\boldsymbol{x}_{t} \mid \boldsymbol{x}_{0}\right)}\) 式子中每一项都是可解的高斯分布,所以我们可以用条件高斯乘积公式,得到: \(p_\theta(x_{t-1} | x_t) = \mathcal{N}(\mu_\theta(x_t, t), \sigma_t^2 I)\) 其中:

均值: \(\mu_\theta = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \cdot \epsilon_\theta(x_t, t) \right)\) 也可以写成基于 x₀ 的形式: \(\mu\left(x_{t}, x_{0}\right)=\frac{\sqrt{\bar{\alpha}_{t-1}} \cdot \beta_{t}}{1-\bar{\alpha}_{t}} x_{0}+\frac{\sqrt{1-\beta_{t}} \cdot\left(1-\bar{\alpha}_{t-1}\right)}{1-\bar{\alpha}_{t}} x_{t}\)

方差

  • 选择:方差$\sigma_t^2$设为常数(不训练),实验发现两种选择效果相似:
    • $\sigma_t^2 = \beta_t$(对应数据初始为高斯分布)
    • $\sigma_t^2 = \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_t} \cdot \beta_t$(对应数据初始为单点分布)

🤔 常见问题解析

Q1:为啥不直接训练一个学 $x_{t-1}$的模型?

  • 空间太大,不容易收敛;
  • 没法直接监督 $x_{t-1}$,但能监督 $\epsilon$

Q2:为什么不直接用预测的 $x_0$ 当作最终的生成结果?

  • 预测的 $x_0$ 是近似值;
  • 多个 $x_t$ 推出的 $x_0$ 不一致;
  • 扩散模型本质是一步步净化,不能一步到位。

Q3:为什么用 UNet 预测噪声 ϵ,而不是直接预测真实反向均值?

  • 因为噪声 ϵ 的分布固定,预测更容易,训练更稳定;
  • 通过数学推导(式10),发现可以改写为预测噪声ϵ的形式,计算更简单

🏋️‍♀️ 模型训练

训练时优化: \(\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0, t, \epsilon} \left[ \left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2 \right]\)

流程如下:

  1. 从真实图像 $x_0$ 采样 $t$,加噪得 $x_t$
  2. 用网络预测噪声 $\epsilon_\theta(x_t, t)$
  3. 用MSE计算损失,与真实 $\epsilon$ 对比

🎨 推理 / 采样阶段

从高斯噪声 $x_T$ 开始,按如下公式逐步采样直到 $x_0$: \(x_{t-1} = \mu_\theta(x_t, t) + \sigma_t \cdot z, \quad z \sim \mathcal{N}(0, I)\) 每步逻辑:

  • 先预测噪声 $\epsilon_\theta$
  • 再估计 $x_0$,推导 $\mu_\theta$
  • 加入随机噪声 $z$ 得到 $x_{t-1}$
  • 不断重复,最终得到生成图像!

🧩 Diffusion 模型模板代码

参考pytorch代码:GitHub - chunyu-li/ddpm: 扩散模型的简易 PyTorch 实现

1. 初始化和必要的导入

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np
import random

2. 定义 Beta Schedule 和相关函数

# 线性 beta schedule(控制每步噪声的大小)
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

# 获取累积的 alpha 值
def get_alphas(betas):
    return 1.0 - betas

# 获取累积 alpha 的乘积
def get_alphas_cumprod(alphas):
    return torch.cumprod(alphas, axis=0)

# 计算反向噪声的标准差
def get_posterior_variance(alphas_cumprod, alphas_cumprod_prev, betas):
    return betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

3. 定义 Diffusion 模型(UNet)

class SimpleUnet(nn.Module):
    def __init__(self):
        super(SimpleUnet, self).__init__()
        # 这里定义一个简单的卷积网络作为示例,可以替换成更复杂的UNet
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(128 * 64 * 64, 256)
        self.fc2 = nn.Linear(256, 3 * 64 * 64)

    def forward(self, x, t):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # Flatten for fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x.view(x.size(0), 3, 64, 64)  # Reshape back to image shape

4. 扩散过程和去噪过程

正向扩散过程:从 $x_0$ 到 $x_t$

def forward_diffusion_sample(x_0, t, betas, alphas_cumprod):
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1)  # 广播至批次维度
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    
    x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
    return x_t, noise

反向扩散过程:从 $x_T$ 到 $x_0$

def sample_timestep(x_t, t, model, alphas_cumprod, betas):
    # 模型预测噪声
    epsilon_pred = model(x_t, t)

    # 计算当前时刻的均值和标准差
    sqrt_alphas_cumprod_t = alphas_cumprod[t].view(-1, 1, 1, 1)
    sqrt_one_minus_alphas_cumprod_t = torch.sqrt(1.0 - alphas_cumprod[t]).view(-1, 1, 1, 1)
    posterior_variance_t = betas[t].view(-1, 1, 1, 1)

    # 计算预测的x_0
    x_0_pred = (x_t - sqrt_one_minus_alphas_cumprod_t * epsilon_pred) / sqrt_alphas_cumprod_t

    # 反向采样
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    x_t_minus_1 = sqrt_alphas_cumprod[t - 1] * x_0_pred + sqrt_one_minus_alphas_cumprod[t - 1] * noise
    return x_t_minus_1

5. 损失函数(训练时)

def get_loss(model, x_0, t, betas, alphas_cumprod):
    x_t, noise = forward_diffusion_sample(x_0, t, betas, alphas_cumprod)
    noise_pred = model(x_t, t)
    return F.mse_loss(noise, noise_pred)  # MSE 损失

6. 数据加载和预处理

def load_transformed_dataset(img_size=64, batch_size=128):
    data_transforms = [
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),  # [0,1] -> [-1,1]
    ]
    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.ImageFolder(root="./stanford_cars/cars_train", transform=data_transform)
    test = torchvision.datasets.ImageFolder(root="./stanford_cars/cars_test", transform=data_transform)

    dataset = torch.utils.data.ConcatDataset([train, test])
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

7. 训练循环

if __name__ == "__main__":
    # 初始化
    model = SimpleUnet()
    T = 300  # 扩散步数
    betas = linear_beta_schedule(T)
    alphas = get_alphas(betas)
    alphas_cumprod = get_alphas_cumprod(alphas)

    BATCH_SIZE = 128
    epochs = 100

    dataloader = load_transformed_dataset(batch_size=BATCH_SIZE)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        for batch_idx, (batch, _) in enumerate(dataloader):
            optimizer.zero_grad()
            batch = batch.to(device)

            t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
            loss = get_loss(model, batch, t, betas, alphas_cumprod)
            loss.backward()
            optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{len(dataloader)}], Loss: {loss.item()}")

8. 采样(生成图像)

def generate_samples(model, T=300):
    # 从噪声开始
    x_t = torch.randn((BATCH_SIZE, 3, 64, 64)).to(device)
    
    for t in reversed(range(T)):
        x_t = sample_timestep(x_t, t, model, alphas_cumprod, betas)
    
    return x_t

9. 显示图像

def show_tensor_image(image):
    image = image.squeeze().cpu().numpy().transpose(1, 2, 0)
    image = (image + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    plt.imshow(image)
    plt.axis('off')
    plt.show()