TorchedUp

114. Mixed-Precision Adam Memory

Medium

In mixed-precision training (forward and backward passes in fp16/bf16, with an fp32 master copy of the weights), compute the total memory in bytes needed for parameters + gradients + optimizer state, given the number of parameters.

Signature: def mixed_precision_memory_bytes(n_params: int) -> int

The convention (Megatron / DeepSpeed):

  • fp16 weights used in forward/backward: 2 bytes
  • fp16 gradients: 2 bytes
  • fp32 master weights (kept by the optimizer for stable updates): 4 bytes
  • fp32 Adam first moment m: 4 bytes
  • fp32 Adam second moment v: 4 bytes

Total: 2 + 2 + 4 + 4 + 4 = 16 bytes per parameter.

Return n_params * 16.
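The per-component accounting above can be sketched in code as follows. This is a minimal illustration, not a reference solution: the `BYTES_PER_PARAM` table and the `memory_breakdown_bytes` helper are hypothetical names introduced here, and only `mixed_precision_memory_bytes` comes from the problem statement.

```python
from typing import Dict

# Bytes per parameter for each tensor kept during mixed-precision Adam
# training, following the convention listed above.
# (BYTES_PER_PARAM is a hypothetical name used only for this sketch.)
BYTES_PER_PARAM: Dict[str, int] = {
    "fp16_weights": 2,
    "fp16_gradients": 2,
    "fp32_master_weights": 4,
    "fp32_adam_m": 4,
    "fp32_adam_v": 4,
}


def memory_breakdown_bytes(n_params: int) -> Dict[str, int]:
    """Hypothetical helper: bytes used by each component for n_params parameters."""
    return {name: n_params * size for name, size in BYTES_PER_PARAM.items()}


def mixed_precision_memory_bytes(n_params: int) -> int:
    """Total bytes for weights + gradients + optimizer state."""
    return sum(memory_breakdown_bytes(n_params).values())


# Usage: a toy model with 1,000 parameters.
breakdown = memory_breakdown_bytes(1_000)
total = mixed_precision_memory_bytes(1_000)
```

Keeping the sizes in a table makes it easy to see where the 16 comes from: the two fp16 tensors contribute 4 bytes per parameter, and the three fp32 optimizer tensors contribute the remaining 12.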

Math

$$M_{\text{mixed}} = N \cdot (2 + 2 + 4 + 4 + 4) = 16 \cdot N$$
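As a worked example of the formula, assume a model with $N = 7 \times 10^{9}$ parameters (an illustrative size, not one taken from the problem):

$$M_{\text{mixed}} = 16 \cdot 7 \times 10^{9} = 1.12 \times 10^{11}\ \text{bytes} = 112\ \text{GB} \approx 104.3\ \text{GiB}$$

Of this total, $4N = 28$ GB is the fp16 weights and gradients, and $12N = 84$ GB is the fp32 master weights and Adam moments. Activations and temporary buffers are not counted.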


NumPy

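import numpy as np

def mixed_precision_memory_bytes(n_params: int) -> int:
    # 2 (fp16 weights) + 2 (fp16 gradients)
    # + 4 (fp32 master weights) + 4 (Adam m) + 4 (Adam v) = 16 bytes per parameter
    return n_params * 16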
