前言

最近遇到一个训练代码,混合精度使用 apex,多卡还是 torch ddp+mp.spawn 子进程启动的方式,性能受限于 python 的 gil 锁。

其实对于混合精度训练 pytorch 已经 merge 进了 amp,fsdp 也支持了 mixed precision policy,多卡训练有 torchrun 启动器,还支持多机分布式。

就在想 pytorch 已经有很多新的 feature,为什么不去用呢,接下来试图用一文说清楚混合精度训练的来龙去脉。

1、起源

17 年,nv 的这篇 paper 提出了混合精度训练:Mixed Precision Training。

https://arxiv.org/abs/1710.03740

【AI大模型面试】阿里二面追问:FP16训练如何避免NaN?看完这一篇你就知道了!!_#人工智能

ref:https://medium.com/data-science/understanding-mixed-precision-training-4b246679c7c4

但用 fp16 去表示 fp32 计算,在训练中会有一些数值问题:

  • 精度下溢/上溢
  • fp16 数值范围和分布不匹配,导致梯度归零

首先第一个问题,fp16 的表示范围小于 fp32,因此会产生 underflow、overflow:

p = torch.tensor([1.0]), device='cuda:0')print(p.dtype, p + 0.0001)p = torch.tensor([1.0]), device='cuda:0').to(torch.float16)print(p.dtype, p + 0.0001)# torch.float32 tensor([1.0001],device='cuda:0')# torch.float16 tensor([1.],device='cuda:0',dtype=torch.float16)a = torch.empty(4096,device='cuda:0').fill_(16.0)print(a.dtype, a.sum())a = torch.tensor(4096,device='cuda:0').to(torch.float16).fill_(16.0)print(a.dtype, a.sum())# torch.float32 tensor(65536.,device='cuda:0')# torch.float16 tensor(inf,device='cuda:0',dtype=torch.float16)

第二个问题,对于 activation gradient 的分布,很大一部分较小的值在 fp16 下是不可表示的,会发生下溢 underflow 被置为 0,导致反向传播中梯度就丢失了。

【AI大模型面试】阿里二面追问:FP16训练如何避免NaN?看完这一篇你就知道了!!_#人工智能_02

ref:https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html#training

混合精度训练中的数值问题和模型量化中遇到的数值问题其实很类似,都是从高精度表示范围映射到低精度表示范围,在量化中是通过 calibration 校准进行不同精度范围的线性映射:scale = x_max/range_max -> x_q = x/scale

在混合精度训练中是引入了 loss scaling 梯度缩放,防止非常小的梯度值在使用 fp16 进行表示时 underflow 下溢变成 0。

在前向传播计算得到损失值 loss,开始反向传播 backward 之前,对 loss 进行缩放,乘以一个大于 1 的常数,称为缩放因子 S,例如 1024、4096 等。

然后用缩放后的 loss 再去进行 backward,又由于求导是基于链式法则,反向传播过程中所有的梯度值都会被进行同等缩放。

但放大的梯度导致后续所有依赖梯度大小进行计算的操作都会失真,权重更新时 w=w-lr*grad,导致权重更新量也会被放大 S 倍,相当于变相增大了学习率 lr,导致训练过程不稳定,且与 fp32 训练的行为不一致。

因此在 backward 之后,更新模型权重之前,包括所有需要依赖梯度大小进行计算的操作之前。

要在 fp32 精度下对权重梯度进行反缩放 unscaling,除以之前放大的缩放因子 S,unscaling 之后的权重梯度也可以应用梯度裁剪和权重衰减等依赖于梯度大小值的操作。

然后 optimizer 里面会 copy 一份 fp32 的主权重 master weights 进行参数更新,更新之后的 master weights 再 cast 到 fp16 同步给模型参数。

【AI大模型面试】阿里二面追问:FP16训练如何避免NaN?看完这一篇你就知道了!!_#职场和发展_03

ref:https://arxiv.org/abs/1710.03740

整个混合精度训练的流程可以表示为:

  • model fp16 weights+fp16 activations
  • fp16 forward
  • fp16 activations
  • fp32 loss scaling
  • weight&activation fp16 backward
  • fp32 grad unscaling
  • fp32 grad clip&decay
  • fp32 master weights update
  • cast to fp16 model weights

2、amp

18 年 nv 以 pytorch 三方扩展的形式推出了 apex,以支持混合精度,20 年 pytorch1.6 merge 进了 torch.cuda.amp,配合 autocast 实现混合精度训练:

scaler = GradScaler()for epoch in epochs:    for input, target in data:        optimizer.zero_grad()        with autocast(device_type='cuda', dtype=torch.float16):            output = model(input)            loss = loss_fn(output, target)        scaler.scale(loss).backward()        # Unscales the gradients of optimizer's assigned params in-place        scaler.unscale_(optimizer)        # Since the gradients of optimizer's assigned params are unscaled, clips as usual:        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,        # although it still skips optimizer.step() if the gradients contain infs or NaNs.        scaler.step(optimizer)        # Updates the scale for next iteration.        scaler.update()

autocast 是混合精度的上下文管理器,在 context 里面会自动选择 op 对应的计算精度,主要基于白名单机制进行自动类型转换:

【AI大模型面试】阿里二面追问:FP16训练如何避免NaN?看完这一篇你就知道了!!_#大模型学习_04

3、fsdp

fsdp 是 21 年 pytorch1.11 的引入的新特性,核心思想来源于 deepspeed zero,在其上又扩展了对混合精度的支持。

【AI大模型面试】阿里二面追问:FP16训练如何避免NaN?看完这一篇你就知道了!!_#AI大模型_05

ref:https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html

与 amp 相比,fsdp 灵活度更高,可以通过 FSDPModule warp 设置不同的混合精度策略。

以及为了更高精度的数值结果,在 fp16 activation 计算和模型参数 all-gather 的基础上,可以使用 fp32 进行 gradient 的 reduce-scatter 和 optimizer 的参数更新:

fsdp_kwargs = {    "mp_policy": MixedPrecisionPolicy(        param_dtype=torch.bfloat16,        reduce_dtype=torch.float32,    )}

QA: 这个时候再思考下,数值计算在什么情况下会发生 nan,迁移到混合精度训练流程里面哪些地方可能会产生 nan,怎么检测 nan,怎么避免 nan,哪些是框架已经做的,哪些还需要自己处理的,问问自己是否有了答案?