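In mixed-precision training (fp16/bf16 forward and backward passes, fp32 master copy of the weights), each parameter costs a fixed number of bytes for parameters, gradients, and optimizer state. Compute the total memory, in bytes, for a model with n_params parameters.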
Signature: def mixed_precision_memory_bytes(n_params: int) -> int
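The convention (Megatron / DeepSpeed, with the Adam optimizer):
- fp16/bf16 parameters: 2 bytes
- fp16/bf16 gradients: 2 bytes
- fp32 master copy of the parameters: 4 bytes
- Adam first moment m (fp32): 4 bytes
- Adam second moment v (fp32): 4 bytes
Total: 2 + 2 + 4 + 4 + 4 = 16 bytes per parameter.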
Return n_params * 16.
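A minimal sketch of the accounting above, assuming Adam as the optimizer and counting only model state (activations, temporary buffers, and fragmentation are excluded). The dictionary name BYTES_PER_PARAM and its keys are hypothetical labels chosen for illustration, not identifiers from Megatron or DeepSpeed. Summing the components rather than hard-coding 16 makes it clear where each byte comes from.

```python
# Hypothetical per-parameter byte breakdown for mixed precision + Adam.
# Keys are illustrative labels only; activations are not counted.
BYTES_PER_PARAM = {
    "fp16/bf16 parameters": 2,
    "fp16/bf16 gradients": 2,
    "fp32 master parameters": 4,
    "fp32 Adam first moment (m)": 4,
    "fp32 Adam second moment (v)": 4,
}


def mixed_precision_memory_bytes(n_params: int) -> int:
    """Total bytes for parameters + gradients + optimizer state."""
    return n_params * sum(BYTES_PER_PARAM.values())


if __name__ == "__main__":
    print(sum(BYTES_PER_PARAM.values()))  # 16
    # A 1.5B-parameter model (GPT-2 scale) needs 1.5e9 * 16 = 24e9 bytes of model state.
    print(mixed_precision_memory_bytes(1_500_000_000))  # 24000000000
```

For scale: the fp16 weights of a 1.5B-parameter model occupy only 3e9 bytes, yet the full model state under this convention is 24e9 bytes, eight times larger.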
def mixed_precision_memory_bytes(n_params: int) -> int:
    # 2 (fp16/bf16 params) + 2 (grads) + 4 (fp32 master) + 4 (Adam m) + 4 (Adam v)
    return n_params * 16