保姆级教程:用PyTorch从零搭建SegFormer语义分割模型(附B0主干网络数据流图解)

张开发
2026/4/12 23:08:43 15 分钟阅读

分享文章

保姆级教程:用PyTorch从零搭建SegFormer语义分割模型(附B0主干网络数据流图解)
# 从零构建SegFormer语义分割模型:PyTorch实战与数据流全解析

语义分割作为计算机视觉领域的核心技术之一,在自动驾驶、医疗影像分析等领域有着广泛应用。而SegFormer凭借其独特的Transformer与CNN混合架构,在精度和效率上取得了显著突破。本文将带您从零开始,用PyTorch完整实现SegFormer模型,并深入剖析其数据流动过程。

## 1. 环境准备与基础模块搭建

在开始构建SegFormer之前,我们需要确保开发环境配置正确。推荐使用Python 3.8和PyTorch 1.10版本,同时安装必要的依赖库:

```bash
pip install torch torchvision opencv-python numpy matplotlib
```

SegFormer的核心由三个基础模块构成:MLP(多层感知机)、ConvModule(卷积模块)和SegFormerHead(解码头)。让我们先实现这些基础组件。

### 1.1 MLP模块实现

MLP模块负责特征通道维度的转换,其实现如下:

```python
class MLP(nn.Module):
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # [B,C,H,W] -> [B,H*W,C]
        x = self.proj(x)  # 线性变换
        x = x.transpose(1, 2).unflatten(2, (x.size(1), x.size(2)))  # 恢复空间维度
        return x
```

这个模块的关键操作包括:

- 空间维度展平与恢复
- 线性变换实现通道维度转换
- 保持批处理维度不变

### 1.2 ConvModule模块设计

ConvModule是标准的卷积-批归一化-激活函数组合:

```python
class ConvModule(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, groups=1, act=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5, momentum=0.1)
        self.act = nn.ReLU() if act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))
```

提示:在实际应用中,可以根据需求调整激活函数类型,比如使用LeakyReLU或SiLU等。

### 1.3 数据维度转换工具函数

为了方便后续处理,我们实现几个常用的维度转换函数:

```python
def reshape_for_mlp(x):
    """将4D张量转换为适合MLP处理的3D格式"""
    B, C, H, W = x.shape
    return x.flatten(2).transpose(1, 2)  # [B,C,H,W] -> [B,H*W,C]

def reshape_for_conv(x, target_shape):
    """将MLP输出恢复为4D卷积格式"""
    return x.transpose(1, 2).view(x.size(0), -1, *target_shape[-2:])
```

## 2. Backbone网络构建

SegFormer采用Mix Vision Transformer(MiT)作为backbone,我们以MiT-B0为例进行实现。

### 2.1 Patch Embedding层

这是将图像转换为token序列的关键层:

```python
class OverlapPatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride,
                              padding=(patch_size // 2, patch_size // 2))
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)  # [B,3,H,W] -> [B,C,H,W]
        x = x.flatten(2).transpose(1, 2)  # [B,C,H,W] -> [B,H*W,C]
        x = self.norm(x)
        return x
```

对于512×512输入图像,经过patch embedding后的维度变化为:输入[B,3,512,512],输出[B,16384,32](H=W=128,128×128=16384)。

### 2.2 Transformer Block实现

Transformer Block是MiT的核心组件:

```python
class EfficientSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, sr_ratio=1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


class MixFFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1,
                                groups=hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.act(x)
        x = x.transpose(1, 2).view(x.size(0), -1, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.fc2(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, sr_ratio=1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientSelfAttention(dim, num_heads, qkv_bias, sr_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MixFFN(dim, int(dim * mlp_ratio))

    def forward(self, x, H, W):
        x = x + self.attn(self.norm1(x), H, W)
        x = x + self.mlp(self.norm2(x), H, W)
        return x
```

### 2.3 MiT-B0完整实现

整合上述组件,我们实现完整的MiT-B0 backbone:

```python
class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=512, in_chans=3,
                 embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8],
                 mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
                 depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.depths = depths

        # Patch embeddings
        self.patch_embed1 = OverlapPatchEmbed(img_size, 7, 4, in_chans, embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size // 4, 3, 2, embed_dims[0], embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size // 8, 3, 2, embed_dims[1], embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size // 16, 3, 2, embed_dims[2], embed_dims[3])

        # Transformer blocks
        self.block1 = nn.ModuleList([TransformerBlock(
            embed_dims[0], num_heads[0], mlp_ratios[0], qkv_bias, sr_ratios[0])
            for _ in range(depths[0])])
        self.block2 = nn.ModuleList([TransformerBlock(
            embed_dims[1], num_heads[1], mlp_ratios[1], qkv_bias, sr_ratios[1])
            for _ in range(depths[1])])
        self.block3 = nn.ModuleList([TransformerBlock(
            embed_dims[2], num_heads[2], mlp_ratios[2], qkv_bias, sr_ratios[2])
            for _ in range(depths[2])])
        self.block4 = nn.ModuleList([TransformerBlock(
            embed_dims[3], num_heads[3], mlp_ratios[3], qkv_bias, sr_ratios[3])
            for _ in range(depths[3])])

        self.norm1 = nn.LayerNorm(embed_dims[0])
        self.norm2 = nn.LayerNorm(embed_dims[1])
        self.norm3 = nn.LayerNorm(embed_dims[2])
        self.norm4 = nn.LayerNorm(embed_dims[3])

    def forward(self, x):
        B = x.shape[0]
        outs = []

        # Stage 1
        x, H1, W1 = self.patch_embed1(x), x.shape[2] // 4, x.shape[3] // 4
        for blk in self.block1:
            x = blk(x, H1, W1)
        x = self.norm1(x)
        x = x.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # Stage 2
        x, H2, W2 = self.patch_embed2(x), H1 // 2, W1 // 2
        for blk in self.block2:
            x = blk(x, H2, W2)
        x = self.norm2(x)
        x = x.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # Stage 3
        x, H3, W3 = self.patch_embed3(x), H2 // 2, W2 // 2
        for blk in self.block3:
            x = blk(x, H3, W3)
        x = self.norm3(x)
        x = x.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # Stage 4
        x, H4, W4 = self.patch_embed4(x), H3 // 2, W3 // 2
        for blk in self.block4:
            x = blk(x, H4, W4)
        x = self.norm4(x)
        x = x.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs
```

## 3. SegFormer解码头实现

SegFormer的解码头负责将backbone提取的多尺度特征融合,并生成最终的分割结果。

### 3.1 多尺度特征融合

```python
class SegFormerHead(nn.Module):
    def __init__(self, in_channels=[32, 64, 160, 256], embedding_dim=256, num_classes=20):
        super().__init__()
        c1, c2, c3, c4 = in_channels

        # 四个MLP用于不同尺度特征转换
        self.linear_c4 = MLP(c4, embedding_dim)
        self.linear_c3 = MLP(c3, embedding_dim)
        self.linear_c2 = MLP(c2, embedding_dim)
        self.linear_c1 = MLP(c1, embedding_dim)

        # 特征融合模块
        self.fusion = ConvModule(embedding_dim * 4, embedding_dim, 1)

        # 最终预测层
        self.predictor = nn.Conv2d(embedding_dim, num_classes, 1)

        # 上采样参数
        self.upsample4 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
        self.upsample3 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, features):
        c1, c2, c3, c4 = features

        # 转换特征维度
        _c4 = self.linear_c4(c4)
        _c3 = self.linear_c3(c3)
        _c2 = self.linear_c2(c2)
        _c1 = self.linear_c1(c1)

        # 上采样到统一尺寸
        _c4 = self.upsample4(_c4)
        _c3 = self.upsample3(_c3)
        _c2 = self.upsample2(_c2)

        # 特征融合
        fused = torch.cat([_c4, _c3, _c2, _c1], dim=1)
        fused = self.fusion(fused)

        # 生成预测
        pred = self.predictor(fused)
        return pred
```

### 3.2 数据流可视化

为了更直观理解SegFormer的数据流动,我们以512×512输入为例,展示各阶段特征图尺寸变化:

| 阶段 | 模块 | 输入尺寸 | 输出尺寸 | 关键操作 |
| --- | --- | --- | --- | --- |
| 输入 | - | [B,3,512,512] | - | - |
| Stage1 | PatchEmbed | [B,3,512,512] | [B,32,128,128] | Conv7x7, stride=4 |
| Stage1 | Transformer×2 | [B,32,128,128] | [B,32,128,128] | 自注意力+MLP |
| Stage2 | PatchEmbed | [B,32,128,128] | [B,64,64,64] | Conv3x3, stride=2 |
| Stage2 | Transformer×2 | [B,64,64,64] | [B,64,64,64] | 自注意力+MLP |
| Stage3 | PatchEmbed | [B,64,64,64] | [B,160,32,32] | Conv3x3, stride=2 |
| Stage3 | Transformer×2 | [B,160,32,32] | [B,160,32,32] | 自注意力+MLP |
| Stage4 | PatchEmbed | [B,160,32,32] | [B,256,16,16] | Conv3x3, stride=2 |
| Stage4 | Transformer×2 | [B,256,16,16] | [B,256,16,16] | 自注意力+MLP |
| Head | MLP转换 | 多尺度特征 | [B,256,128,128] | 线性变换+上采样 |
| Head | 特征融合 | [B,1024,128,128] | [B,256,128,128] | 1x1卷积 |
| Head | 预测 | [B,256,128,128] | [B,num_classes,128,128] | 1x1卷积 |

## 4. 完整SegFormer模型集成

现在我们将backbone和解码头组合成完整的SegFormer模型:

```python
class SegFormer(nn.Module):
    def __init__(self, num_classes=20, phi='b0', pretrained=False):
        super().__init__()
        # Backbone配置
        self.backbone_config = {
            'b0': {
                'embed_dims': [32, 64, 160, 256],
                'num_heads': [1, 2, 5, 8],
                'depths': [2, 2, 2, 2],
                'sr_ratios': [8, 4, 2, 1],
            },
            'b1': {
                'embed_dims': [64, 128, 320, 512],
                'num_heads': [1, 2, 5, 8],
                'depths': [2, 2, 2, 2],
                'sr_ratios': [8, 4, 2, 1],
            },
            # 其他尺寸配置...
        }

        # 初始化backbone
        config = self.backbone_config[phi]
        self.backbone = MixVisionTransformer(
            embed_dims=config['embed_dims'],
            num_heads=config['num_heads'],
            depths=config['depths'],
            sr_ratios=config['sr_ratios'],
        )

        # 初始化解码头
        self.decode_head = SegFormerHead(
            in_channels=config['embed_dims'],
            embedding_dim=256,
            num_classes=num_classes,
        )

        # 加载预训练权重
        if pretrained:
            self.load_pretrained(phi)

    def forward(self, x):
        features = self.backbone(x)
        pred = self.decode_head(features)
        return F.interpolate(pred, size=x.shape[2:], mode='bilinear', align_corners=False)

    def load_pretrained(self, phi):
        # 这里实现预训练权重加载逻辑
        pass
```

## 5. 模型训练与优化

### 5.1 损失函数设计

语义分割常用的损失函数组合:

```python
class SegLoss(nn.Module):
    def __init__(self, num_classes, ignore_index=255):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=ignore_index)
        self.dice_loss = DiceLoss(num_classes, ignore_index)

    def forward(self, pred, target):
        ce = self.ce_loss(pred, target)
        dice = self.dice_loss(pred, target)
        return 0.5 * ce + 0.5 * dice
```

### 5.2 训练配置建议

SegFormer训练时推荐使用以下配置:

```python
def get_optimizer(model, lr=6e-5, weight_decay=0.01):
    param_groups = [
        {'params': [p for n, p in model.named_parameters() if 'backbone' in n],
         'lr': lr},
        {'params': [p for n, p in model.named_parameters() if 'decode_head' in n],
         'lr': lr * 10},
    ]
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)


def get_scheduler(optimizer, total_epochs):
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=[group['lr'] for group in optimizer.param_groups],
        total_steps=total_epochs,
        pct_start=0.05,
        anneal_strategy='cos',
    )
```

### 5.3 数据增强策略

针对语义分割任务的有效数据增强:

```python
train_transform = A.Compose([
    A.RandomResizedCrop(512, 512, scale=(0.5, 2.0)),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
    A.GaussianBlur(blur_limit=(3, 7), p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])
```

更多文章