TSM行为识别实战:从UCF101数据集准备到模型训练,保姆级避坑指南

张开发
2026/4/8 11:11:15 15 分钟阅读

分享文章

TSM行为识别实战:从UCF101数据集准备到模型训练,保姆级避坑指南
TSM行为识别实战从UCF101数据集准备到模型训练全流程解析在计算机视觉领域行为识别一直是一个极具挑战性的研究方向。不同于静态图像分类视频行为识别需要模型能够理解时间维度上的动作变化。TSMTemporal Shift Module作为近年来备受关注的时间建模方法通过巧妙地在2D CNN中引入时序建模能力在保持高效推理的同时显著提升了行为识别性能。本文将手把手带你完成从UCF101数据集准备到TSM模型训练的全过程特别针对实际项目中容易遇到的坑点提供解决方案。1. UCF101数据集准备与预处理UCF101作为行为识别领域的基准数据集包含101类动作的13320个视频片段。但在实际应用中我们往往不需要使用全部类别。以下是经过优化的数据处理流程1.1 数据集下载与结构分析首先从官网获取UCF101数据集其目录结构通常如下UCF101/ ├── ApplyEyeMakeup/ │ ├── v_ApplyEyeMakeup_g01_c01.avi │ └── ... ├── ApplyLipstick/ │ ├── v_ApplyLipstick_g01_c01.avi │ └── ... └── ...关键注意事项视频文件以.avi格式存储每个子目录对应一个动作类别文件名包含视频属性信息如g01表示组别c01表示摄像机角度1.2 视频抽帧处理TSM等行为识别模型通常以视频帧序列作为输入。我们使用FFmpeg进行抽帧import os import subprocess def extract_frames(video_path, output_dir, fps25, size(340, 256)): 使用FFmpeg从视频中提取帧 if not os.path.exists(output_dir): os.makedirs(output_dir) cmd fffmpeg -i {video_path} -r {fps} -s {size[0]}x{size[1]} -q:v 2 {output_dir}/image_%05d.jpg subprocess.call(cmd, shellTrue)常见问题解决方案抽帧速度慢降低目标分辨率或帧率内存不足分批处理视频文件时间戳错乱确保使用%05d格式保证文件名排序正确1.3 生成标签文件UCF101官方提供了训练/测试划分文件。我们需要将其转换为模型可读的格式def generate_label_file(video_dir, output_path): classes sorted(os.listdir(video_dir)) class_to_idx {cls_name: i for i, cls_name in enumerate(classes)} with open(output_path, w) as f: for cls_name in classes: cls_dir os.path.join(video_dir, cls_name) for video in os.listdir(cls_dir): frame_dir os.path.join(cls_dir, video.split(.)[0]) frame_count len(os.listdir(frame_dir)) f.write(f{frame_dir} {frame_count} {class_to_idx[cls_name]}\n)提示对于大型数据集建议使用多进程加速标签生成过程2. TSM模型环境配置2.1 依赖安装推荐使用conda创建Python 3.7环境conda create -n tsm python3.7 conda activate tsm pip install torch1.8.0 torchvision0.9.0 pip install opencv-python ffmpeg-python2.2 代码库克隆与修改从官方仓库克隆TSM代码git clone https://github.com/mit-han-lab/temporal-shift-module.git cd temporal-shift-module关键修改点在dataset_config.py中更新数据集路径根据GPU内存调整batch_size参数修改main.py中的学习率调度策略3. 模型训练与调优3.1 训练命令解析基础训练命令示例python main.py ucf101 RGB \ --arch resnet50 \ --num_segments 8 \ --lr 0.001 \ --epochs 50 \ --batch-size 32 \ --dropout 0.5 \ --consensus_typeavg \ --shift --shift_div8 --shift_placeblockres \ --tune_frompretrained/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e50.pth参数说明num_segments输入视频划分的片段数shift_div控制时序移位比例tune_from预训练权重路径3.2 常见训练问题解决问题1预训练权重加载失败解决方案修改模型加载逻辑# 在main.py中找到权重加载部分添加以下处理 if module. in list(sd.keys())[0]: sd {k.replace(module., ): v for k,v in sd.items()}问题2GPU内存不足优化策略减小batch_size使用梯度累积尝试更小的模型架构如MobileNetV2问题3过拟合应对方法增加dropout值0.5-0.8添加数据增强随机裁剪、水平翻转等使用早停策略4. 模型测试与部署4.1 视频推理实现官方代码缺少直接的视频输入接口我们需要自行实现import torch import cv2 import numpy as np from models import TSN def video_inference(model, video_path, num_segments8): # 初始化模型 model.eval() # 视频读取 cap cv2.VideoCapture(video_path) frames [] while cap.isOpened(): ret, frame cap.read() if not ret: break frame preprocess(frame) # 预处理函数 frames.append(frame) # 分段采样 segment_length len(frames) // num_segments indices [i*segment_length for i in range(num_segments)] inputs torch.stack([frames[i] for i in indices]) # 推理 with torch.no_grad(): outputs model(inputs.unsqueeze(0)) return outputs.argmax().item()4.2 性能优化技巧时序插值对于短视频使用线性插值增加帧数多裁剪测试采用5-crop四角中心提升准确率模型量化使用PyTorch的量化模块减小模型体积5. 进阶应用与扩展5.1 自定义数据集训练当需要处理非UCF101数据时注意保持与UCF101相同的目录结构确保视频长度适中建议3-10秒类别数量变化时需要修改最后的全连接层5.2 与其他模型的对比模型准确率(UCF101)参数量推理速度(FPS)TSN88.5%23.5M45TSM94.2%24.3M52I3D89.4%12.3M285.3 实际应用建议工业场景优先考虑MobileNetV2架构的TSM平衡精度与速度研究场景尝试不同的shift_div参数4/8/16研究时序建模影响边缘设备结合TVM等工具进行模型编译优化在完成基础训练后可以进一步探索多模态融合RGB光流自监督预训练长视频时序建模改进

更多文章