Mixed precision training uses fp16 for the forward and backward passes (faster, less GPU memory) but fp32 for the optimizer update (prevents catastrophic precision loss in weight updates).
The key challenge: fp16 has a narrow dynamic range (its smallest positive value is about 6e-8). Gradients with magnitudes well below that get flushed to exactly 0 ("underflow"), which stalls learning. The fix is loss scaling: multiply the loss by a large constant before backward, then divide the gradients by the same constant before the optimizer step.
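To make the underflow concrete, here is a minimal NumPy sketch. The gradient value 1e-8 and the scale 65536 are illustrative assumptions, not values from the problem statement; the cast to np.float16 stands in for what an fp16 backward pass would store.

```python
import numpy as np

tiny_grad = 1e-8        # hypothetical gradient, below fp16's smallest positive value
loss_scale = 65536.0    # hypothetical scale factor (a power of two)

# Stored directly in fp16, the gradient underflows to zero.
unscaled_fp16 = np.float16(tiny_grad)

# Scaled first, it lands in fp16's representable range.
scaled_fp16 = np.float16(tiny_grad * loss_scale)

# Unscaling in fp32 recovers a value close to the original gradient.
recovered = np.float32(scaled_fp16) / np.float32(loss_scale)

print(unscaled_fp16 == 0)   # True
print(scaled_fp16 != 0)     # True
print(recovered)
```

Using a power of two for the scale means the division in fp32 only changes the exponent, so the only rounding error comes from the fp16 cast itself.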
Steps:
- Multiply the loss by loss_scale before calling backward, so the resulting gradients are also scaled by loss_scale (this lifts tiny gradients out of fp16's underflow region).
- Divide the gradients by loss_scale to recover the true (unscaled) gradient.

Signature: def mixed_precision_step(params_fp32, grad_fp32, loss_scale, lr)
- params_fp32: (N,) current fp32 parameters
- grad_fp32: (N,) gradients in fp32 (before overflow check)
- loss_scale: float, the scale factor used (check for overflow by scaling grad)
- lr: float, learning rate

Returns (new_params, skipped): the updated params and a bool (True if the step was skipped due to inf/nan).
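import numpy as np

def mixed_precision_step(params_fp32, grad_fp32, loss_scale, lr):
    # Return (new_params, skipped) as described in the parameter list above.
    pass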
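One way the stub could be filled in is sketched below. This is not the reference solution. It rests on several assumptions that the problem text only hints at: the overflow check is simulated by scaling grad_fp32 and casting it to fp16, a skipped step returns the parameters unchanged, and the update is plain SGD (params - lr * grad).

```python
import numpy as np

def mixed_precision_step(params_fp32, grad_fp32, loss_scale, lr):
    params = np.asarray(params_fp32, dtype=np.float32)
    grad = np.asarray(grad_fp32, dtype=np.float32)

    # Assumption: simulate the fp16 backward pass by scaling the gradient
    # and casting to fp16. Values beyond fp16's range become inf.
    with np.errstate(over="ignore", invalid="ignore"):
        scaled_fp16 = (grad * np.float32(loss_scale)).astype(np.float16)

    # Overflow check: any inf/nan means the step is skipped.
    if not np.all(np.isfinite(scaled_fp16)):
        return params.copy(), True

    # Unscale in fp32 to recover the true gradient.
    unscaled = scaled_fp16.astype(np.float32) / np.float32(loss_scale)

    # Assumption: plain SGD update on the fp32 parameters.
    new_params = params - np.float32(lr) * unscaled
    return new_params, False


if __name__ == "__main__":
    p = np.array([1.0, 2.0], dtype=np.float32)
    print(mixed_precision_step(p, np.array([0.5, -0.25]), 1024.0, 0.1))
    print(mixed_precision_step(p, np.array([100.0, 0.5]), 1024.0, 0.1))
```

In the second call, 100.0 * 1024 exceeds fp16's largest finite value (65504), so the scaled gradient becomes inf and the step is skipped. A real training loop would typically also lower loss_scale after a skipped step, which is outside this function's signature.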