给你一张 GPU(比如 A100)参数,让你自己算出 flash attention “最优 tile size”

张开发
2026/4/14 9:00:52 15 分钟阅读

分享文章

给你一张 GPU(比如 A100)参数,让你自己算出 flash attention “最优 tile size”
可以。我们就拿A100做一个“工程上够用的手算版”。目标不是算到绝对最优而是学会这套方法给定 GPU 资源怎么把 FlashAttention 的 tile size 收敛到一个合理区间。我先说结论再带你推。一、先给一个工程结论对A100这类卡FlashAttention 风格 kernel 常见会落在这类量级BLOCK_M 64 或 128BLOCK_N 64 或 128BLOCK_K head_dim 分块后的一个合适宽度常见 32 / 64 / 128总线程数常见128或256这不是拍脑袋而是被四个东西卡出来的shared memory 容量register 压力warp 数量load / compute 是否能重叠下面我们一步一步算。二、先列 A100 里你最该关心的资源你现在先只记这几个不要全背芯片规格。对一个SM你可以近似用这组数来思考每个 SM 最多约2048 threads每个 SM 最多64 warps每个 SM 可用寄存器总量约65536 个 32-bit registersshared memory 可用量大约按160KB 级别去想就够了你做 tile 设计时本质是在问一个 block 占多少 shared memory一个 thread 占多少 register这样一个 SM 能同时挂几个 block / 几个 warp三、先从最硬的约束开始shared memoryFlashAttention 的一个 block最少要反复缓存这些东西中的一部分K_tileV_tile有时还要双缓冲再加上一些辅助空间所以第一步通常不是想 warp而是先看tile 放不放得下 shared memory1. 一个最简单的估算模型假设数据类型是 FP16/BF16每个元素按2 bytes算。如果我们先取一个候选BLOCK_N 128HEAD_DIM 64那么一块K_tile的大小大约是128×64×216384 bytes16KB128 \times 64 \times 2 16384 \text{ bytes} 16KB128×64×216384bytes16KB同样一个V_tile也是 16KB。所以单缓冲只放K和V16KB16KB32KB16KB 16KB 32KB16KB16KB32KB如果做双缓冲32KB×264KB32KB \times 2 64KB32KB×264KB这时候你就会发现BLOCK_N128, d64在 shared memory 上是很舒服的。2. 如果 head_dim 128 呢再算一次128×128×232768 bytes32KB128 \times 128 \times 2 32768 \text{ bytes} 32KB128×128×232768bytes32KB那么K_tile 32KBV_tile 32KB单缓冲就已经64KB64KB64KB双缓冲就是128KB128KB128KB这已经很紧了但在 A100 上仍然是“可能可做”的量级。所以你会立刻得到一个重要直觉head_dim 从 64 变成 128tile 选择空间会骤缩。这也是为什么很多高性能 kernel 不会把所有维度都贪大。四、第二个约束register这一步是最容易把 occupancy 干爆的地方。FlashAttention 里每个 thread 不只是搬数据它还要持有输出累加器acc在线 softmax 的m在线 softmax 的lQ 的局部片段临时中间值所以 tile 一大thread 手上的活就多register 就上升。1. 一个工程上的粗判断你先不要追求精确公式先记这个经验如果每 thread 只用30~50 registers通常还比较从容如果到64~96 registers就开始明显压 occupancy如果破百很多设计就会变得危险2. 为什么 tile 大会推高 register因为 tile 大意味着每个 thread 常常要负责更多输出元素要保留更多 partial sums要保留更多 softmax 状态比如一个 thread 如果只算 1 个输出值register 压力小。但如果一个 thread 要算 4 个、8 个甚至更多输出值寄存器就会往上长。所以第二条结论是tile 不能只看 shared memory还要看每个 thread 背了多少累加器。五、第三个约束warp 数和 block size这一步把执行形态定住。A100 上经验上你通常会先从这两个 block size 开始试128 threads/block256 threads/block原因很直接128 4 warps256 8 warps这两个量级比较容易同时满足warp 数够调度灵活不至于让单 block 太肥为什么不先上 512因为 FlashAttention 这种 kernel 通常已经很吃shared memoryregisterblock 再大很容易导致一个 SM 同时只能挂 1 个 blockoccupancy 下去overlap 变差所以第三条结论是先把 block size 限在 128 或 256是很稳的起点。六、第四个约束让 compute 足够盖住 IO这是最像“真正最优 tile”的那一步。FlashAttention 不是只要能跑就行而是希望算当前 tile 的时间差不多能覆盖下一 tile 的搬运时间。也就是你前面问的 streaming / double buffering 那件事。如果 tile 太小会发生什么一次 load 很快做完一次 compute 也很快做完但 compute 太短盖不住 HBM latencykernel 还是偏 IO-bound如果 tile 太大compute 变长了但 shared memory / register 爆了occupancy 掉了反而不一定更快所以真正的“最好”不是越大越好而是大到足以让 compute 覆盖 IO但又没把资源压爆。七、现在真的开始“手算一个候选配置”我们做一个 A100 上很典型的场景dtype FP16/BF16head_dim 64目标选一组初始 tile方案 A取BLOCK_M 128BLOCK_N 128d 64threads 1281. shared memory 估算K_tile:128×64×216KB128 \times 64 \times 2 16KB128×64×216KBV_tile:128×64×216KB128 \times 64 \times 2 16KB128×64×216KB单缓冲总共32KB32KB32KB双缓冲64KB64KB64KB再加一点额外辅助空间还是很合理。2. warp 数128 threads 4 warps可以但不算很多。3. register如果每个 thread 负责的输出块别太大通常还能维持在一个可接受区间。4. 结论这是一个非常合理的起点。方案 B取BLOCK_M 64BLOCK_N 128d 64threads 128这时候 shared memory 几乎不变因为主要缓存的还是K_tile/V_tile而它们更直接受BLOCK_N × d影响。但BLOCK_M变小以后每个 block 负责的 Q 行更少每 thread 的累加压力可能下降register 可能更友好但单 block 的计算量也变少所以这类配置常常会在更高 occupancy更低单块复用之间做平衡。结论如果128×128太吃 register64×128常常是下一档非常自然的选择。方案 C取BLOCK_M 128BLOCK_N 64d 64这时候 shared memory 直接更轻64×64×28KB per tile64 \times 64 \times 2 8KB \text{ per tile}64×64×28KBper tile两个 tile 单缓冲也才 16KB双缓冲也很轻松。但问题是BLOCK_N小了每次沿 K/V 方向推进得更短HBM 往返次数增多streaming 效率可能下降结论这类配置更稳、更保守但可能不如128宽度的 tile 吃满带宽。八、所以怎么“选最优”而不是“拍一个”真正的流程一般是第一步先定 block threads先试128256第二步shared memory 先过筛把明显放不下的 tile 直接删掉。第三步register 再过筛看编译后每 thread register 用量删掉 occupancy 太差的。第四步实测最后在这几组候选里 benchmark64×6464×128128×64128×128因为真正最优会受很多细节影响head_dim 是 64 还是 128causal mask 有无batch sizeseq len编译器生成的指令质量所以“最优 tile”不是纯数学算出来的而是先靠资源模型缩到 3~6 个候选再跑出来。九、给你一个非常实用的“起手模板”如果你自己在 A100 上做 FlashAttention 风格 kernel可以这样起手当head_dim 64先试BLOCK_M 128, BLOCK_N 128, threads 128BLOCK_M 64, BLOCK_N 128, threads 128BLOCK_M 128, BLOCK_N 64, threads 128当head_dim 128先试BLOCK_M 64, BLOCK_N 128, threads 128BLOCK_M 64, BLOCK_N 64, threads 128BLOCK_M 128, BLOCK_N 64, threads 128因为这时 shared memory 和 register 压力都比d64大很多不能像前者那样激进。十、你可以把这件事压缩成一句话A100 上选 FlashAttention tile本质是先让 K/V tile 能安全放进 shared memory再让每 thread 的累加器别把 register 撑爆最后让 128 或 256 个线程形成足够的 warp 并行来盖住 IO。十一、再给你一个最短判断法以后你看到一组 tile脑子里就问这四句K_tile和V_tile放进 shared memory 后还能留双缓冲空间吗每个 thread 要背多少 acc / m / l / q fragmentregister 会不会爆block 是 4 warp 还是 8 warp够不够调度这一块的 compute 时间能不能大致盖住下一块的 load只要这四句能答得顺tile 就八九不离十了。下一步最合适的是我直接拿一个具体配置比如A100 head_dim128 seq_len4096带你从头手算一遍候选 tile并解释为什么某几组会被淘汰。

更多文章