混合精度训练完全指南:从理论到实践
混合精度是大模型训练的标配技术。本文从原理到实践,全面讲解混合精度训练。
一、什么是混合精度?
定义
混合精度(Mixed Precision):
- 同时使用 FP16 和 FP32 两种精度
- 前向/反向传播用 FP16
- 权重更新用 FP32
精度对比
| 精度 | 位数 | 范围 | 精度 |
|---|---|---|---|
| FP32 | 32 | ±3.4e38 | 7 位有效数字 |
| FP16 | 16 | ±65504 | 3-4 位有效数字 |
| BF16 | 16 | ±3.4e38 | 7 位有效数字 |
优势
显存:
- FP16 占用减半
- 支持更大 batch size
- 支持更大模型
- Tensor Core 加速
- 带宽需求减半
- 训练速度提升 2-3 倍
二、实现原理
基本流程
1. FP32 权重 → 复制 → FP16 权重
2. FP16 前向传播 → 计算 Loss
3. FP16 反向传播 → 计算 FP16 梯度
4. FP16 梯度 → 转换 → FP32 梯度
5. FP32 梯度 → 更新 → FP32 权重关键问题
1. 下溢(Underflow)
- FP16 范围小,小梯度可能丢失
- 解决:Loss Scaling
- FP16 范围小,大值可能溢出
- 解决:梯度裁剪
- 累加操作精度损失
- 解决:关键操作保留 FP32
三、PyTorch 实现
基础用法(AMP)
python
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for batch in dataloader:
optimizer.zero_grad()
with autocast():
outputs = model(inputs)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()关键组件
autocast:
- 自动选择精度
- 白名单操作(FP16 加速)
- 黑名单操作(FP32 保证精度)
- 动态 Loss Scaling
- 防止梯度下溢
- 自动调整 scale 因子
进阶配置
python
from torch.cuda.amp import GradScaler, autocast
# 自定义 scaler 配置
scaler = GradScaler(
init_scale=65536.0,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=2000
)
# 自定义 autocast 策略
with autocast(dtype=torch.float16):
# 强制某些操作使用 FP32
with autocast(enabled=False):
critical_op = fp32_function(x)四、BF16 vs FP16
对比
| 特性 | FP16 | BF16 |
|---|---|---|
| 指数位 | 5 | 8 |
| 尾数位 | 10 | 7 |
| 动态范围 | 小 | 大(同 FP32) |
| 精度 | 高 | 低 |
| 适用场景 | 通用 | 大模型训练 |
选择建议
FP16 适合:
- 通用场景
- A100 之前 GPU
- 对精度要求高
- 大模型训练
- A100/H100 GPU
- 稳定性优先
BF16 实现
python
# PyTorch 1.10+ 支持 BF16
model = model.bfloat16()
# 或使用 autocast
with autocast(dtype=torch.bfloat16):
outputs = model(inputs)五、常见问题与解决
问题 1:Loss 变成 NaN
原因:
- 梯度溢出
- Learning rate 太大
- Loss Scaling 不合适
python
# 降低学习率
optimizer = Adam(model.parameters(), lr=1e-5)
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# 调整 scaler
scaler = GradScaler(init_scale=1024.0)问题 2:精度下降明显
原因:
- 某些操作不适合 FP16
- 累加次数过多
python
# 关键操作使用 FP32
with autocast(enabled=False):
accurate_sum = fp32_accumulate(x)
# 减少累加次数
# 使用梯度累积代替大 batch问题 3:速度提升不明显
原因:
- 数据加载瓶颈
- CPU 瓶颈
- 网络通信瓶颈
- 优化数据管道
- 增加 num_workers
- 检查 GPU 利用率
六、框架支持
PyTorch
版本要求: 1.6+
API: torch.cuda.amp
特点: 灵活、易用
TensorFlow
版本要求: 2.4+
API: tf.keras.mixed_precision
配置:
python
from tensorflow.keras import mixed_precision
mixed_precision.set_global_policy('mixed_float16')DeepSpeed
支持: ZeRO + 混合精度
配置:
json
{
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
}
}七、最佳实践
1. 逐步迁移
步骤:
2. 监控指标
关键指标:
- Loss 曲线(对比 FP32)
- 梯度范数
- 显存占用
- 训练速度
3. 调优技巧
- 从保守配置开始(低 init_scale)
- 逐步增加 batch size
- 配合梯度累积
- 定期保存 checkpoint
总结
混合精度训练是大模型训练的必备技能:
建议: 从现在开始,所有训练任务默认开启混合精度。
*有问题欢迎交流讨论!*