扩散模型实战:从零开始用PyTorch搭建你的第一个图像生成器(附完整代码)

张开发
2026/5/22 7:09:07 15 分钟阅读
扩散模型实战:从零开始用PyTorch搭建你的第一个图像生成器(附完整代码)
扩散模型实战从零开始用PyTorch搭建你的第一个图像生成器当你在社交媒体上看到那些由AI生成的逼真头像时是否好奇它们是如何被创造出来的扩散模型作为当前最先进的生成技术正在重塑我们创造和想象图像的方式。本文将带你从零开始用PyTorch实现一个能够生成MNIST手写数字的基础扩散模型。1. 扩散模型基础概念扩散模型的核心思想是通过逐步添加噪声破坏数据再学习如何逆转这个过程。想象一杯清水滴入墨水的过程——扩散模型正是模拟这种从有序到无序再从无序重建有序的逆向工程。关键组件解析正向过程将数据逐渐转化为高斯噪声的马尔可夫链反向过程通过神经网络学习从噪声中重建数据的去噪步骤噪声调度控制每个时间步添加的噪声量# 典型噪声调度器实现 def linear_beta_schedule(timesteps): beta_start 0.0001 beta_end 0.02 return torch.linspace(beta_start, beta_end, timesteps)提示DDPM(去噪扩散概率模型)通常使用1000个时间步这是噪声添加和去除的迭代次数2. 环境准备与数据加载在开始构建模型前我们需要配置开发环境并准备数据集。建议使用Python 3.8和PyTorch 1.12版本。环境依赖PyTorch with CUDA支持torchvisionmatplotlib用于可视化tqdm进度条显示pip install torch torchvision matplotlib tqdmMNIST数据集加载与预处理from torchvision import datasets, transforms transform transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) train_dataset datasets.MNIST( root./data, trainTrue, downloadTrue, transformtransform ) train_loader torch.utils.data.DataLoader( train_dataset, batch_size128, shuffleTrue )3. 构建UNet噪声预测器UNet架构因其编码器-解码器结构特别适合扩散模型任务。我们的实现将包含下采样块编码器中间块上采样块解码器时间步嵌入关键实现细节class Block(nn.Module): def __init__(self, in_ch, out_ch, time_emb_dim): super().__init__() self.time_mlp nn.Linear(time_emb_dim, out_ch) self.conv1 nn.Conv2d(in_ch, out_ch, 3, padding1) self.conv2 nn.Conv2d(out_ch, out_ch, 3, padding1) def forward(self, x, t): h self.conv1(x) time_emb self.time_mlp(t)[:, :, None, None] h h time_emb h F.relu(h) return self.conv2(h)时间步嵌入采用Transformer中的正弦位置编码def timestep_embedding(timesteps, dim): half_dim dim // 2 emb math.log(10000) / (half_dim - 1) emb torch.exp(torch.arange(half_dim, devicedevice) * -emb) emb timesteps[:, None] * emb[None, :] emb torch.cat((emb.sin(), emb.cos()), dim-1) return emb4. 训练流程实现扩散模型的训练过程需要精心设计噪声添加和损失计算策略。以下是关键训练步骤随机采样时间步计算对应噪声预测噪声并计算损失反向传播更新参数训练循环核心代码def train_loop(model, loader, optimizer, device): model.train() for batch_idx, (data, _) in enumerate(loader): data data.to(device) # 随机采样时间步 t torch.randint(0, timesteps, (data.shape[0],), devicedevice).long() # 生成噪声 noise torch.randn_like(data) # 添加噪声 x_t q_sample(data, t, noise) # 预测噪声 predicted_noise model(x_t, t) # 计算损失 loss F.mse_loss(predicted_noise, noise) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step()注意使用Adam优化器时学习率通常设置为1e-4到2e-4之间5. 采样与图像生成训练完成后我们可以通过逐步去噪从随机噪声生成新图像。采样过程是训练反向过程的实现torch.no_grad() def p_sample(model, x, t, t_index): betas_t extract(betas, t, x.shape) sqrt_one_minus_alphas_cumprod_t extract( sqrt_one_minus_alphas_cumprod, t, x.shape ) sqrt_recip_alphas_t extract(sqrt_recip_alphas, t, x.shape) # 预测噪声 pred_noise model(x, t) # 计算均值 model_mean sqrt_recip_alphas_t * ( x - betas_t * pred_noise / sqrt_one_minus_alphas_cumprod_t ) if t_index 0: return model_mean else: posterior_variance_t extract(posterior_variance, t, x.shape) noise torch.randn_like(x) return model_mean torch.sqrt(posterior_variance_t) * noise完整采样流程torch.no_grad() def p_sample_loop(model, shape): device next(model.parameters()).device # 从纯噪声开始 img torch.randn(shape, devicedevice) imgs [] for i in tqdm(reversed(range(0, timesteps)), desc采样循环): img p_sample(model, img, torch.full((shape[0],), i, devicedevice, dtypetorch.long), i) imgs.append(img.cpu().numpy()) return imgs6. 高级技巧与优化要让扩散模型达到最佳性能还需要考虑以下优化策略显存优化技术混合精度训练梯度检查点分布式数据并行训练稳定性提升EMA指数移动平均模型权重学习率预热梯度裁剪# EMA实现示例 class EMA: def __init__(self, beta): super().__init__() self.beta beta self.step 0 def update_model_average(self, ema_model, current_model): for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()): old_weight, new_weight ema_params.data, current_params.data ema_params.data self.update_average(old_weight, new_weight) def update_average(self, old, new): if old is None: return new return old * self.beta (1 - self.beta) * new7. 结果评估与可视化评估生成模型质量常用FID(Fréchet Inception Distance)指标但对于MNIST这类简单数据集我们可以直接观察生成样本。生成样本可视化def plot_images(images): plt.figure(figsize(10, 10)) plt.imshow(torch.cat([ torch.cat([i for i in images.cpu()], dim-1), ], dim-2).permute(1, 2, 0).cpu()) plt.show()训练过程中可以定期保存检查点并生成样本if epoch % 10 0: sampled_images sample(model, image_sizeimage_size, batch_size64) plot_images(sampled_images[-1]) torch.save(model.state_dict(), fddpm_model_{epoch}.pth)在实际项目中我发现调整噪声调度策略对生成质量影响显著。线性调度简单但效果不错而余弦调度通常能产生更自然的过渡。另一个关键点是UNet中残差连接的设计——它们能有效缓解深层网络的梯度消失问题。

更多文章