Mixed precision training uses fp16 for the forward and backward passes (faster, less GPU memory) but fp32 for the optimizer update (prevents catastrophic precision loss in weight updates).
The key challenge: fp16 has a much narrower dynamic range than fp32, so gradients near zero get flushed to exactly 0 ("underflow"), killing learning. The fix is loss scaling: multiply the loss by a large constant before the backward pass, then divide the gradients by the same constant before the optimizer step.
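As a rough illustration of why scaling helps, here is a small NumPy sketch (the gradient value and scale factor are made-up demonstration numbers, not part of the problem):

```python
import numpy as np

# A gradient smaller than fp16's smallest subnormal (~6e-8) is flushed to 0,
# but survives the round trip through fp16 once it is scaled up first.
grad = 1e-8
print(np.float16(grad))                    # 0.0  -> the signal is lost

loss_scale = 2.0 ** 16
scaled = np.float16(grad * loss_scale)     # ~6.55e-4, representable in fp16
print(np.float32(scaled) / loss_scale)     # ~1e-8, recovered after unscaling
```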
Steps:
1. scaled_loss = loss * loss_scale
2. scaled_grad = grad * loss_scale
3. unscaled_grad = scaled_grad / loss_scale
4. params = params - lr * unscaled_grad

Signature: def mixed_precision_step(params_fp32, grad_fp32, loss_scale, lr)
Parameters:
- params_fp32: (N,) current fp32 parameters
- grad_fp32: (N,) gradients in fp32 (before the overflow check)
- loss_scale: float, the scale factor used (check for overflow by scaling grad)
- lr: float, learning rate

Returns: (new_params, skipped), the updated params and a bool (True if the step was skipped due to inf/nan)
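A minimal sketch of one possible implementation, assuming NumPy arrays and a plain SGD update; the overflow check scales the gradients and tests them for inf/nan, as described above:

```python
import numpy as np

def mixed_precision_step(params_fp32, grad_fp32, loss_scale, lr):
    # Re-create the scaled gradients (what a backward pass on the scaled
    # loss would have produced) and check them for overflow.
    scaled_grad = grad_fp32 * loss_scale
    if not np.all(np.isfinite(scaled_grad)):
        # inf/nan detected: skip this step and leave the params untouched.
        return params_fp32, True

    # Unscale back to the true gradients and apply the update in fp32.
    unscaled_grad = scaled_grad / loss_scale
    new_params = params_fp32 - lr * unscaled_grad
    return new_params, False

# Example usage with made-up values.
params = np.array([0.5, -1.0, 2.0], dtype=np.float32)
grads = np.array([0.1, -0.2, 0.05], dtype=np.float32)
new_params, skipped = mixed_precision_step(params, grads, loss_scale=2.0**16, lr=0.01)
print(new_params, skipped)   # params moved slightly downhill, skipped == False
```

Returning the skipped flag rather than raising is what a dynamic loss-scaling loop typically relies on: after a skip it can shrink loss_scale and try again, and after a run of clean steps it can grow it.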