Diffusion Model 3:DDPM 逆扩散过程推导
扩散模型(Diffusion Models)是近年来生成式建模中的重要技术,具有生成质量高、灵活性强的特点。DDPM(Denoising Diffusion Probabilistic Model)是扩散模型的经典代表,其核心思想是通过逐步添加噪声构造一个易于建模的分布,然后反向去噪生成高质量样本。
本文聚焦DDPM的逆扩散过程,从原理推导到代码实现,结合图解帮助你轻松掌握这一重要技术。
1. 什么是扩散模型?
扩散模型基于两个过程:
- 正向扩散(Forward Diffusion):从真实数据分布开始,通过逐步添加高斯噪声将其变换为标准正态分布。
- 逆向扩散(Reverse Diffusion):从标准正态分布出发,逐步去噪还原到数据分布。
2. DDPM的正向扩散过程
数学定义
正向扩散从真实数据 ( x_0 ) 开始,定义一系列中间状态 ( x_1, x_2, \dots, x_T ),满足以下条件:
其中:
- ( \alpha_t \in (0, 1) ) 是控制噪声强度的参数。
正向过程的多步表示为:
其中 ( \bar{\alpha}_t = \prod_{s=1}^t \alpha_s )。
3. 逆扩散过程推导
3.1 目标分布
逆扩散的目标是学习条件分布:
我们假设其形式为高斯分布:
3.2 参数化过程
为了简化建模,通常假设 ( \Sigma_\theta(x_t, t) ) 是对角矩阵或常数,重点放在学习 ( \mu_\theta(x_t, t) )。通过变分推导可以得到:
其中:
- ( \epsilon_\theta(x_t, t) ) 是用于预测噪声的神经网络。
4. DDPM逆扩散过程实现
以下是用PyTorch实现DDPM的核心模块,包括正向扩散和逆向生成。
4.1 正向扩散过程
import torch
import torch.nn as nn
import numpy as np
class DDPM(nn.Module):
def __init__(self, beta_start=1e-4, beta_end=0.02, timesteps=1000):
super(DDPM, self).__init__()
self.timesteps = timesteps
self.betas = torch.linspace(beta_start, beta_end, timesteps) # 噪声调度参数
self.alphas = 1 - self.betas
self.alpha_bars = torch.cumprod(self.alphas, dim=0) # 累积乘积
def forward_diffusion(self, x0, t):
"""正向扩散过程: q(x_t | x_0)"""
sqrt_alpha_bar_t = torch.sqrt(self.alpha_bars[t]).unsqueeze(1)
sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - self.alpha_bars[t]).unsqueeze(1)
noise = torch.randn_like(x0)
xt = sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * noise
return xt, noise
# 示例:正向扩散
timesteps = 1000
ddpm = DDPM(timesteps=timesteps)
x0 = torch.randn(16, 3, 32, 32) # 假设输入图片
t = torch.randint(0, timesteps, (16,))
xt, noise = ddpm.forward_diffusion(x0, t)
4.2 逆扩散过程
逆扩散过程依赖一个噪声预测网络 ( \epsilon_\theta ),通常使用U-Net实现。
class UNet(nn.Module):
def __init__(self, in_channels=3, out_channels=3, hidden_channels=64):
super(UNet, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
)
self.decoder = nn.Sequential(
nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1)
)
def forward(self, x):
return self.decoder(self.encoder(x))
# 逆扩散实现
def reverse_diffusion(ddpm, unet, xt, timesteps):
for t in reversed(range(timesteps)):
t_tensor = torch.full((xt.size(0),), t, device=xt.device, dtype=torch.long)
alpha_t = ddpm.alphas[t].unsqueeze(0).to(xt.device)
alpha_bar_t = ddpm.alpha_bars[t].unsqueeze(0).to(xt.device)
sqrt_recip_alpha_t = torch.sqrt(1.0 / alpha_t)
sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)
pred_noise = unet(xt)
xt = sqrt_recip_alpha_t * (xt - sqrt_one_minus_alpha_bar_t * pred_noise)
return xt
# 示例:逆扩散
unet = UNet()
xt_gen = reverse_diffusion(ddpm, unet, xt, timesteps)
5. 图解DDPM逆扩散
正向扩散过程
- 数据逐步添加噪声,逐渐接近标准正态分布。
公式图示:
- ( x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon )
逆扩散过程
- 从随机噪声开始,通过逐步去噪恢复数据。
公式图示:
- ( x_{t-1} = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta) )
6. 总结
本文从原理推导出发,详细解析了DDPM的逆扩散过程,结合代码示例和图解,帮助你理解扩散模型的核心思想。扩散模型正在快速成为生成式AI的关键技术,DDPM为实现高质量图像生成提供了一个强大的框架。未来,可以通过改进噪声调度或引入更多条件控制(如文本或标签)进一步增强其能力。