Mixed Precision Training Step (Medium)

Mixed precision training uses fp16 for the forward and backward passes (faster, less GPU memory) but fp32 for the optimizer update (prevents catastrophic precision loss in weight updates).
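
To see why the update must stay in fp32, here is a quick numpy check (illustrative only, not part of the required solution): near 1.0, fp16 values are spaced about 0.001 apart, so a typical small update rounds away to nothing.

    import numpy as np

    # fp16 spacing near 1.0 is 2**-10 ~= 0.00098, so an update of 1e-4
    # falls below half the spacing and rounds back to exactly 1.0.
    print(np.float16(1.0) + np.float16(1e-4))   # 1.0    (update lost)
    print(np.float32(1.0) + np.float32(1e-4))   # 1.0001 (update preserved)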

The key challenge: fp16 has a tiny dynamic range, with the smallest positive subnormal around 6e-8. Gradients below that floor get flushed to exactly 0 ("underflow"), killing learning. The fix is loss scaling: multiply the loss by a large constant before backward, then divide the gradients by the same constant before the optimizer step.
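
A small numpy demonstration of both halves of that fix; the 2**16 scale is just a common default and is assumed here, not something the problem specifies.

    import numpy as np

    grad = np.float32(1e-8)                 # tiny but real gradient
    print(np.float16(grad))                 # 0.0, below fp16's ~6e-8 floor

    loss_scale = 2.0 ** 16                  # assumed initial scale
    scaled = np.float16(grad * loss_scale)
    print(scaled)                           # ~0.000655, representable in fp16
    print(np.float32(scaled) / loss_scale)  # ~1e-8, recovered after unscaling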

Steps:

  1. Cast activations to fp16, run forward pass
  2. Scale loss: scaled_loss = loss * loss_scale
  3. Backward produces scaled gradients: scaled_grad = grad * loss_scale
  4. Check the scaled gradients for inf/nan; if any are found, skip the update entirely (gradient overflow; see the np.isfinite check after this list)
  5. Unscale: unscaled_grad = scaled_grad / loss_scale
  6. fp32 SGD update: params = params - lr * unscaled_grad
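
For step 4, a single np.isfinite test covers both failure modes, since it returns False for inf and nan alike:

    import numpy as np

    g = np.array([0.1, np.inf, np.nan], dtype=np.float32)
    print(np.isfinite(g))              # [ True False False]
    print(not np.all(np.isfinite(g)))  # True, so skip this update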

Signature: def mixed_precision_step(params_fp32, grad_fp32, loss_scale, lr)

  • params_fp32: (N,) current fp32 parameters
  • grad_fp32: (N,) gradients in fp32 (before overflow check)
  • loss_scale: float — the scale factor; overflow is detected by checking grad_fp32 * loss_scale for inf/nan
  • lr: float — learning rate
  • Returns: (new_params, skipped) — the updated params and a bool that is True if the update was skipped due to inf/nan (a numpy sketch of the whole function follows this list)
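
A minimal numpy sketch of that signature. One ambiguity worth flagging: the spec does not say whether the scaled gradient should round-trip through fp16 before the check; this version checks grad_fp32 * loss_scale directly, which still catches the inf/nan cases the tests describe.

    import numpy as np

    def mixed_precision_step(params_fp32, grad_fp32, loss_scale, lr):
        # Recreate the scaled gradients the backward pass would have produced.
        scaled_grad = grad_fp32 * loss_scale

        # Gradient overflow check: any inf/nan makes this step unusable,
        # so return the parameters untouched and flag the skip.
        if not np.all(np.isfinite(scaled_grad)):
            return params_fp32.copy(), True

        # Unscale (exact when loss_scale is a power of two), then apply
        # a plain SGD update on the fp32 master weights.
        unscaled_grad = scaled_grad / loss_scale
        return params_fp32 - lr * unscaled_grad, False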

Python (numpy)

Test Results

  • normal update
  • inf gradient — skip update
  • nan gradient — skip update
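
Exercising the sketch above against the three listed cases:

    import numpy as np

    params = np.array([1.0, -0.5], dtype=np.float32)
    scale, lr = 2.0 ** 16, 0.1

    # normal update: gradients finite, parameters move
    print(mixed_precision_step(params, np.array([0.2, -0.1], np.float32), scale, lr))
    # (array([ 0.98, -0.49], dtype=float32), False)

    # inf gradient: parameters unchanged, skipped is True
    print(mixed_precision_step(params, np.array([np.inf, 0.0], np.float32), scale, lr))
    # (array([ 1. , -0.5], dtype=float32), True)

    # nan gradient: parameters unchanged, skipped is True
    print(mixed_precision_step(params, np.array([np.nan, 0.0], np.float32), scale, lr))
    # (array([ 1. , -0.5], dtype=float32), True)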