PyTorch实战:从零搭建Mask R-CNN模型并优化COCO数据集训练

张开发
2026/4/12 21:37:28 15 分钟阅读

分享文章

PyTorch实战:从零搭建Mask R-CNN模型并优化COCO数据集训练
1. 环境配置与源码准备第一次接触Mask R-CNN时我也被复杂的依赖关系搞得头大。经过多次实践我总结出一套最稳定的环境配置方案。建议使用Python 3.8和PyTorch 1.10的组合这个版本区间兼容性最好。先创建一个干净的conda环境conda create -n maskrcnn python3.8 conda activate maskrcnn安装PyTorch时要注意CUDA版本匹配。如果你的显卡是30系列建议这样安装pip install torch1.10.0cu113 torchvision0.11.1cu113 -f https://download.pytorch.org/whl/torch_stable.html源码推荐使用MMDetection框架的实现比原生PyTorch版本更易用git clone https://github.com/open-mmlab/mmdetection cd mmdetection pip install -v -e .这里有个坑要注意MMCV的版本必须与PyTorch严格匹配。我测试下来最稳定的组合是pip install mmcv-full1.6.02. COCO数据集处理技巧COCO2017数据集有超过12万张图片解压后约25GB。我建议这样组织目录结构mmdetection ├── data │ └── coco │ ├── annotations │ ├── train2017 │ └── val2017处理标注文件时我发现一个常见问题有些开发者会遗漏关键步骤。正确的做法是将instances_train2017.json和instances_val2017.json放在annotations目录确保图片文件名与标注文件中的image_id对应验证数据集完整性from pycocotools.coco import COCO coco COCO(data/coco/annotations/instances_train2017.json) print(len(coco.getImgIds())) # 应该输出118287如果遇到内存不足的问题可以修改configs/base/datasets/coco_detection.py中的ImageToTensor变换添加to_float32False参数减少内存占用。3. 模型配置与训练优化Mask R-CNN的配置文件位于configs/mask_rcnn目录。我强烈建议先复制一份默认配置cp configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco_custom.py关键修改点包括修改num_classes参数COCO默认80类调整学习率策略optimizer dict( typeSGD, lr0.02, # 8GPU时的基准学习率 momentum0.9, weight_decay0.0001) optimizer_config dict(grad_clipNone)单卡训练时需要按比例降低学习率lr 0.02 / 8 # 单卡学习率训练命令推荐使用分布式训练即使只有一张卡./tools/dist_train.sh configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco_custom.py 14. 训练监控与调试技巧训练过程中我习惯用TensorBoard监控指标tensorboard --logdirwork_dirs几个关键指标需要特别关注loss_rpn_cls建议值0.01-0.05loss_mask稳定在0.2左右较理想mAP0.5:0.95COCO基准应在0.35以上如果遇到NaN损失可以尝试降低学习率添加梯度裁剪optimizer_config dict(grad_clipdict(max_norm35, norm_type2))5. 模型评估与结果可视化评估模型性能时我发现官方提供的test.py脚本有些参数很实用python tools/test.py \ configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco_custom.py \ work_dirs/latest.pth \ --eval bbox segm \ --show-dir results可视化结果时可以修改mmdet/core/visualization/image.py中的draw_masks函数调整mask的透明度def draw_masks(...): masks masks.astype(np.bool) colors [np.array((0, 255, 0))] # 修改mask颜色 alpha 0.5 # 调整透明度6. 小数据集训练策略当使用气球这类小数据集时我总结出几个有效技巧数据增强配置train_pipeline [ dict(typeLoadImageFromFile), dict(typeLoadAnnotations, with_bboxTrue, with_maskTrue), dict(typeResize, img_scale(1333, 800), keep_ratioTrue), dict(typeRandomFlip, flip_ratio0.5), dict(typeNormalize, ...), dict(typePad, size_divisor32), dict(typeDefaultFormatBundle), dict(typeCollect, keys[img, gt_bboxes, gt_labels, gt_masks]), ]冻结骨干网络前几层model dict( backbonedict( frozen_stages2, # 冻结前2个stage norm_cfgdict(requires_gradFalse)))使用更小的batch_size和更长的训练周期data dict( samples_per_gpu1, # 单卡batch_size workers_per_gpu2) runner dict(typeEpochBasedRunner, max_epochs100)7. 常见问题解决方案在项目实践中我遇到过这些典型问题问题1RuntimeError: CUDA out of memory解决方案减小batch_size使用更小的输入尺寸img_norm_cfg dict( mean[123.675, 116.28, 103.53], std[58.395, 57.12, 57.375], to_rgbTrue) train_pipeline [ dict(typeResize, img_scale(800, 600), keep_ratioTrue), ... ]问题2验证集指标波动大解决方案增加验证间隔evaluation dict(interval2, metric[bbox, segm])使用更稳定的优化器optimizer dict( typeAdamW, lr0.0001, weight_decay0.0001)问题3预测时出现重复框解决方案调整NMS阈值model dict( test_cfgdict( rcnndict( score_thr0.05, nmsdict(typenms, iou_threshold0.5), max_per_img100)))8. 进阶优化技巧经过多次实验我发现这些优化手段效果显著使用Swin Transformer作为backbonemodel dict( backbonedict( typeSwinTransformer, embed_dims96, depths[2, 2, 6, 2], num_heads[3, 6, 12, 24], window_size7, mlp_ratio4, qkv_biasTrue, qk_scaleNone, drop_rate0., attn_drop_rate0., drop_path_rate0.2, patch_normTrue), neckdict(...))添加注意力机制model dict( neckdict( typeFPN, in_channels[256, 512, 1024, 2048], out_channels256, num_outs5, add_extra_convson_output, relu_before_extra_convsTrue), rpn_headdict( typeRPNHead, in_channels256, feat_channels256, anchor_generatordict(...), loss_clsdict(...), loss_bboxdict(...)), roi_headdict( typeStandardRoIHead, bbox_roi_extractordict(...), bbox_headdict( typeShared2FCBBoxHead, in_channels256, fc_out_channels1024, roi_feat_size7, num_classes80, bbox_coderdict(...), reg_class_agnosticFalse, loss_clsdict(...), loss_bboxdict(...)), mask_roi_extractordict(...), mask_headdict( typeFCNMaskHead, num_convs4, in_channels256, conv_out_channels256, num_classes80, loss_maskdict(...))))使用混合精度训练fp16 dict(loss_scale512.)这些配置需要根据具体硬件条件调整建议先在小型数据集上测试效果。我在实际项目中通过这些优化将mAP提升了约15%。

更多文章