PyTorch深度学习实战 |手算​​U-net

张开发
2026/4/10 14:57:18 15 分钟阅读

分享文章

PyTorch深度学习实战 |手算​​U-net
欢迎来到PyTorch深度学习实战的世界博客主页卿云阁欢迎关注点赞收藏⭐️留言首发时间2026年4月8日✉️希望可以和大家一起完成进阶之路作者水平很有限如果发现错误请留言轰炸哦万分感谢目录U-Net 与 FCN 的本质区别U-net核心剖析网络结构卷积模式跳跃连接弹性形变推理过程损失函数的计算代码的实现U-Net 与 FCN 的本质区别FCN 的跳跃连接是“相加Add”就像是把两张半透明的图纸叠在一起看颜色特征值融合了但图纸的厚度通道数没变。U-Net 的跳跃连接是“拼接Concat”就像是把一本浅层特征的书和一本深层特征的书直接装订在一起。厚度通道数直接翻倍网络保留了所有的原始信息让后面的卷积层自己去挑想要看哪一页。维度相加 (Add)拼接 (Concat)数学操作X_1 X_2 (对应元素相加)[X_1, X_2] (通道维度扩充)通道数变化不变 (C)增加 (C_1 C_2)信息保留程度破坏性混合(难以追溯来源)无损保留(各自独立)计算复杂度极低(对后续卷积无压力)很高(后续卷积参数激增)梯度回传1:1 无损复制回传根据后续卷积的权重分配回传代表网络FCN, ResNet, FPNU-Net, DenseNet, YOLO (PANet)相加 (Add) —— FCN极致的计算与内存效率因为是逐像素对应相加融合后的特征图通道数 C 保持不变。这意味着紧跟在它后面的卷积层的参数量和计算量完全不会增加。极其省显存天生的梯度高速公路相加操作在反向传播时它的导数是 1。这意味着梯度可以等价地、无损地同时传递给浅层和深层网络非常有利于缓解梯度消失问题这就是 ResNet 为什么能训到 100 层的根本原因。信息的破坏性混合“颜色融合了”意味着网络再也无法区分“哪部分信息来自浅层哪部分来自深层”。原始的、高分辨率的空间细节比如极细的边缘可能会在相加的过程中被深层的抽象语义给“中和”甚至淹没掉。拼接 (Concat) —— U-Net 优点 (Pros)100% 的信息无损保留这是它在医学图像分割U-Net 老本行中封神的原因。浅层特征包含极其清晰的边界、纹理信息被完好无损地打包送到了后面。不管经历了多少次下采样那些可能被丢弃的细微空间信息全都在。特征选择权交给了网络我们不去强行干预信息怎么混合而是把变厚的“书”直接扔给后面的卷积层。后接的 3* 3或 1 * 1 卷积拥有自己的权重矩阵网络会通过训练自己决定“这部分我要多看浅层的细节那部分我要多看深层的语义”。U-net核心剖析网络结构 第一阶段左侧编码器下坡提取特征输入图像是一张单通道图尺寸为 [1, 572, 572]。第一层 (Stage 1)经过第一次 3* 3 卷积572 - 2 570。维度变为 [64, 570, 570]。经过第二次 3 * 3 卷积570 - 2 568。维度变为 [64, 568, 568]。经过 2*2最大池化尺寸除以 2568 / 2 284。下沉至下一层 [64, 284, 284]。第二层 (Stage 2)通道翻倍。两次卷积284 - 2 282282 - 2 280。此时输出 [128, 280, 280]。池化后尺寸减半280 / 2 140。下沉至 [128, 140, 140]。第三层 (Stage 3)通道翻倍。两次卷积140 - 2 138138 - 2 136。此时输出 [256, 136, 136]。池化后尺寸减半136 / 2 68。下沉至 [256, 68, 68]。第四层 (Stage 4)通道翻倍。两次卷积68 - 2 6666 - 2 64。此时输出 [512, 64, 64]。池化后尺寸减半64 / 2 32。下沉至谷底 [512, 32, 32]。⚓ 第二阶段谷底瓶颈层最抽象的语义此时来到了 U 型结构的最底部这里不进行池化只做两次卷积。通道翻倍至极限。两次卷积32 - 2 3030 - 2 28。谷底输出[1024, 28, 28]。这张 28* 28的图包含了全图最全局的上下文信息。 第三阶段右侧解码器上坡与残酷的裁剪这里就是 U-Net 最神奇的“Copy and Crop (复制与裁剪)”环节。因为之前卷积导致图像不断缩小左右两侧的特征图尺寸是对不上的第一步上坡 (Up 1)将谷底的 [1024, 28, 28] 经过 2 *2 上采样反卷积。尺寸翻倍通道减半变为 [512, 56, 56]。【裁剪与拼接】 回顾左侧第四层尺寸是 64 *64。为了和当前的 56 * 56拼接必须把左侧特征图周围多余的边界“一刀切掉”使其中心保留 56* 56的区域。在通道维度拼接后通道数翻倍[512 512, 56, 56] to[1024, 56, 56]。两次卷积56 - 2 5454 - 2 52。本层最终输出 [512, 52, 52]。第二步上坡 (Up 2)上采样[512, 52, 52] to [256, 104, 104]。提取左侧第三层136 * 136裁剪至 104 * 104。拼接后通道翻倍至 512。两次卷积104 - 2 102102 - 2 100。输出 [256, 100, 100]。第三步上坡 (Up 3)上采样[256, 100, 100] to [128, 200, 200]提取左侧第二层280*280裁剪至 200 * 200。拼接后通道翻倍至 256。两次卷积200 - 2 198198 - 2 196。输出 [128, 196, 196]。第四步上坡 (Up 4)上采样[128, 196, 196] to[64, 392, 392]提取左侧第一层最原始568 *568大量裁剪至 392 * 392。拼接后通道翻倍至 128。两次卷积392 - 2 390390 - 2 388。输出 [64, 388, 388]。 第四阶段最终分类输出拿到了 [64, 388, 388] 的超高分辨率特征图后原作者使用了 1 个 1 * 1 卷积层。1 * 1 卷积不会改变图像的尺寸只负责把 64 个通道压缩到类别数比如前景和背景 2 个类别。最终完美降落[2, 388, 388]。总结一下编码器 (Encoder - 提取特征)通过连续的“卷积 2*2 最大池化”逐步提取深层语义每次池化使特征图尺寸严格减半。解码器 (Decoder - 恢复细节)通过 2*2 反卷积将尺寸翻倍。随后将它与左侧对应的浅层特征进行拼接 (Concat) 并再次卷积融合。核心动作 (Crop - 裁剪)由于全程采用无填充卷积Valid Padding每次卷积尺寸减 2左右两端特征图尺寸并不匹配。在拼接前必须对左侧较大的特征图进行中心裁剪 (Crop) 以强制对齐。输出层 (Output)最后由 1*1 卷积将通道数降为类别数如 2 维。输入572*572 的图像经过层层“缩水”最终输出为 388*388。(悬念为什么语义分割允许输入输出尺寸不一致我们后文揭晓)卷积模式U-Net 编码器中特征图缩小正是因为使用了 Valid 模式。这三种卷积模式的核心区别如下Full 模式全相交只要卷积核的边缘“碰”到图像就开始计算。补白最多输出尺寸变大。Same 模式等大小卷积核中心对齐图像边缘开始计算。适当补白保证输出尺寸与输入保持一致工程中最常用。Valid 模式无填充卷积核必须 100% 完全包含在图像内部才计算。绝不补白输出尺寸变小U-Net 原版采用的模式。跳跃连接随着卷积网络的加深图像的纹理、边缘等浅层细节会被不可逆地压缩甚至丢失。为了在最终的分割图里找回这些边界我们引入了跳跃连接Skip Connection——将低级语义信息“抄近道”直接补偿给高级语义特征。在 FCN 中我们使用了 Add而在 U-Net 中则使用了 Concat。操作方式机制本质核心优势致命劣势Add(FCN)对应像素相加通道数不变省算力计算量小不增加后续卷积的参数负担。融特征相当于残差学习。破坏独立性浅层细节可能被深层响应淹没小尺度特征容易丢失。Concat(U-Net)通道维度拼接通道数增加保细节无损保留全层级信息让网络自己学习如何挑选特征提取能力极强。吃显存拼接后通道翻倍导致后续卷积的计算量和参数量呈平方级暴涨。Overlap-tile 策略医学图像通常极其庞大显存根本无法一次性塞下整张图必须用“滑动窗口”切成小块Patch来处理。这就带来了图像分割的一个核心矛盾切片边缘的像素失去了周围的上下文信息会导致边缘预测极度不准。为了解决这个问题U-Net 作者没有用传统的 Padding 补零而是极其优雅地提出了 Overlap-tile (重叠切片) 策略。其核心逻辑可以概括为三点拿大图测小图解决上下文缺失这正是输入 572*572输出却是 388*388 的根本原因模型每次吃进一张 572*572的大视野图像利用外围丰富的上下文特征去极致且严谨地预测它正中心那块 388*388的核心区域。核心拼图解决重叠冗余滑动窗口在移动时大视野输入框会互相重叠但我们只取模型输出的 388*388 核心预测结果。这些核心结果就像贴瓷砖Tile一样严丝合缝地拼成完整的全图分割结果。镜像填充 Mirror Extrapolation解决绝对物理边界当切片滑动到原图的最边缘外面没有像素了怎么办作者采用“镜像翻折”的方法把原图内部的像素像照镜子一样翻折到外面伪造上下文。这样原图最边缘的图像也能得到完美的特征预测。弹性形变因为细胞和生物组织本就是柔软的具有天然的不规则扭曲特性。普通的旋转和平移太死板了而弹性变换能够完美模拟真实细胞在切片时的自然形变与结构畸变从而极大提升了模型的泛化能力。生成扰动给原图上的每一个像素坐标 (x, y)都生成一个 [-1, 1] 区间内的纯随机偏移量。平滑与放大如果直接加随机数图片会变成全是马赛克的噪点图。所以我们要用高斯滤波把这些杂乱的偏移量“抹平滑”再乘上一个缩放系数Alpha来放大扭曲的力度得到最终的偏移量(delta_x, delta_y)。映射像素将原图上的像素生拉硬拽到新的坐标 (xdelta_x, ydelta_y) 处中间产生的空隙通过插值算法补齐。推理过程网络的输入 (Input) 细胞分割任务中通常是单张灰度切片所以张量为[1, 1, 572, 572]。网络的输出是[Num_Classes, 388, 388]。如果只是区分“细胞”和“背景”通道数就是 2。医院拍出的一张病理切片分辨率可能高达 10,000*10,000 像素。显卡根本不可能一次性吞下这么大的图。如果强行把图 resize 缩放到 572细胞全糊了根本无法诊断。所以U-Net 在推理时的核心思想是“原分辨率切片 逐块预测 无缝拼接”。Step 1: 网格化目标区 (Grid the Output Tiles)首先算法会在那张 10,000* 10,000的超大原图上画网格。每一个网格的尺寸严格等于网络的输出大小即 388 * 388。这些网格就是我们最终要像贴瓷砖一样贴满全图的目标块它们之间绝不重叠。Step 2: 截取带上下文的输入框 (Extract Context Crop)为了预测网格中某一块 388 *388 的瓷砖程序会以这块瓷砖为中心向外扩张在原图上截取一个 572 * 572 的大视野图。特殊情况处理如果这个 572 的大框超出了原图的物理边界算法就会立刻启动“镜像翻折Mirroring”把内部的图像翻折到外面补齐。Step 3: 前向传播 (Forward Pass)将这块准备好的 572* 572大图送入训练好的 U-Net 模型。网络内部经过特征提取、跳跃连接Concat、一刀刀的裁剪Crop最终吐出一张精准的 388*388 预测结果。Step 4: 无缝拼图 (Seamless Stitching)拿到这 388* 388 的预测结果后把它放到刚才 Step 1 划分好的对应网格位置上。接着移动大视野框去处理下一块瓷砖。周而复始直到所有的 388* 388小块严丝合缝地填满那张10,000 * 10,000 的大画布。现在的Unet版本在现代版本的 U-Net 中最核心的特征就是输入图片大小 输出图片大小比如你输入一张 512x512 的图片网络最后输出的预测结果Mask也完完全全是 512x512。之所以能做到这么“所见即所得”是因为现代 U-Net 抛弃了原版那个会吃掉边缘的 ValidPadding全面采用了 Same Padding在代码里通常是设置 padding1。这样每次卷积后图像尺寸不会缩水跳跃连接Concat时左右两边大小完美对齐再也不需要做任何裁剪Crop了。损失函数的计算带空间权重的像素级交叉熵损失 (Pixel-wise Weighted Cross-Entropy Loss)权重矩阵是怎么得到的呐权重矩阵是通过测量背景中每个像素点“距离最近的两个细胞有多远”来自动计算的。程序会计算某个背景像素到离它最近的细胞边缘的距离d_1以及到第二近的细胞边缘的距离d_2。如果这个像素刚好夹在两个快要粘连的细胞中间那么 d_1 和 d_2都会非常小数学公式就会像触发警报一样把这两个微小的距离代入指数函数瞬间计算出一个极高的“惩罚倍数”赋予该像素而对于那些远离所有细胞的普通空旷背景算出来的权重就只是普通的 1。这就相当于用数学方法精准锁定了所有细胞间的“狭窄缝隙”逼迫网络在训练时绝不能把它们连在一起。代码的实现第一部分双卷积基础模块 Down_Up_Convimport torch import torch.nn as nn class Down_Up_Conv(nn.Module): U-Net 的基础组件连续两次 3x3 卷积 包含Conv2d - BatchNorm2d - ReLU - Conv2d - BatchNorm2d - ReLU def __init__(self, in_channels, out_channels, kernel_size3, stride1, padding1): super(Down_Up_Conv, self).__init__() self.conv_block nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_sizekernel_size, stridestride, paddingpadding, biasFalse), nn.BatchNorm2d(out_channels), nn.ReLU(inplaceTrue), nn.Conv2d(out_channels, out_channels, kernel_sizekernel_size, stridestride, paddingpadding, biasFalse), nn.BatchNorm2d(out_channels), nn.ReLU(inplaceTrue) ) def forward(self, x): return self.conv_block(x)第二部分跳跃连接与裁剪拼接 crop_and_concatdef crop_and_concat(upsampled, bypass): 将两个 feature map 在 H 和 W 上对齐后拼接dim1 - upsampled: 解码器上采样后的特征图 (N, C1, H1, W1) - bypass: 编码器传来的特征图 (N, C2, H2, W2) h1, w1 upsampled.shape[2], upsampled.shape[3] h2, w2 bypass.shape[2], bypass.shape[3] # 计算尺寸差值 delta_h h2 - h1 delta_w w2 - w1 # 对 encoder 输出 (bypass) 进行中心裁剪使其和 upsampled 一样大 bypass_cropped bypass[:, :, delta_h // 2: delta_h // 2 h1, delta_w // 2: delta_w // 2 w1] # 沿通道维(dim1)进行拼接 return torch.cat([bypass_cropped, upsampled], dim1)第三部分完整的 UNet 模型class UNet(nn.Module): def __init__(self, num_classes2): super(UNet, self).__init__() # 编码器 Encoder self.stage_down1 Down_Up_Conv(3, 64) self.stage_down2 Down_Up_Conv(64, 128) self.stage_down3 Down_Up_Conv(128, 256) self.stage_down4 Down_Up_Conv(256, 512) # 瓶颈层 Bottleneck self.stage_down5 Down_Up_Conv(512, 1024) # 上采样层 # 优化采用 kernel_size2, stride2 是最标准且高效的尺寸翻倍方式 self.up4 nn.ConvTranspose2d(1024, 512, kernel_size2, stride2) self.up3 nn.ConvTranspose2d(512, 256, kernel_size2, stride2) self.up2 nn.ConvTranspose2d(256, 128, kernel_size2, stride2) self.up1 nn.ConvTranspose2d(128, 64, kernel_size2, stride2) # 解码器 Decoder self.stage_up4 Down_Up_Conv(1024, 512) # 拼接后通道为 5125121024卷积后降回 512 self.stage_up3 Down_Up_Conv(512, 256) self.stage_up2 Down_Up_Conv(256, 128) self.stage_up1 Down_Up_Conv(128, 64) # 输出层 Output (修正 Bug 处) # 绝不能用带 ReLU 和 BN 的卷积必须使用单个 1x1 卷积输出原始 Logits self.outc nn.Conv2d(64, num_classes, kernel_size1) self.maxpool nn.MaxPool2d(kernel_size2, stride2) self.initialize_weights() def initialize_weights(self): Kaiming 权重初始化非常专业的习惯 for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): nn.init.kaiming_normal_(m.weight, modefan_out, nonlinearityrelu) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, x): # --- Encoder --- stage1 self.stage_down1(x) # [N, 64, H, W] x self.maxpool(stage1) stage2 self.stage_down2(x) # [N, 128, H/2, W/2] x self.maxpool(stage2) stage3 self.stage_down3(x) # [N, 256, H/4, W/4] x self.maxpool(stage3) stage4 self.stage_down4(x) # [N, 512, H/8, W/8] x self.maxpool(stage4) # --- Bottleneck --- stage5 self.stage_down5(x) # [N, 1024, H/16, W/16] # --- Decoder --- x self.up4(stage5) # [N, 512, H/8, W/8] x crop_and_concat(x, stage4) # 拼接后: [N, 1024, H/8, W/8] x self.stage_up4(x) # [N, 512, H/8, W/8] x self.up3(x) x crop_and_concat(x, stage3) x self.stage_up3(x) x self.up2(x) x crop_and_concat(x, stage2) x self.stage_up2(x) x self.up1(x) x crop_and_concat(x, stage1) x self.stage_up1(x) # --- Output --- out self.outc(x) # [N, num_classes, H, W] return out

更多文章