TorchedUp

117. ZeRO Memory per GPU

Medium

DeepSpeed ZeRO shards optimizer state, gradients, and parameters across data-parallel GPUs. Compute the per-GPU memory in bytes for parameters + grads + optimizer state for a given ZeRO stage.

Signature: def zero_memory_per_gpu_bytes(n_params: int, world_size: int, stage: int) -> int

Use mixed-precision conventions (in the formulas below, N denotes n_params):

  • fp16 weights: 2 bytes/param
  • fp16 grads: 2 bytes/param
  • fp32 master + Adam m + Adam v: 4 + 4 + 4 = 12 bytes/param
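
The three byte counts above add up to 16 bytes per parameter before any sharding is applied. A minimal sketch of that bookkeeping follows; the constant names and the helper `unsharded_bytes` are hypothetical and are not part of the problem's required interface.

```python
# Per-parameter byte costs taken from the mixed-precision conventions above.
# All names here are illustrative assumptions, not part of the problem.
FP16_WEIGHT_BYTES = 2
FP16_GRAD_BYTES = 2
OPTIMIZER_BYTES = 4 + 4 + 4  # fp32 master copy + Adam m + Adam v


def unsharded_bytes(n_params: int) -> int:
    """Bytes per GPU when every rank holds a full replica (no sharding)."""
    per_param = FP16_WEIGHT_BYTES + FP16_GRAD_BYTES + OPTIMIZER_BYTES
    return n_params * per_param


print(unsharded_bytes(1_000))  # 16000
```

Each ZeRO stage then divides some subset of these 16 bytes by world_size, which is what the stage formulas express.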

Stage 1 (shard optimizer state only):

bytes = N * (2 + 2) + N * 12 / world_size

Stage 2 (shard optimizer state + gradients):

bytes = N * 2 + N * (2 + 12) / world_size

Stage 3 (shard everything):

bytes = N * (2 + 2 + 12) / world_size = N * 16 / world_size

Return an integer: apply integer (floor) division by world_size to the sharded term.
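
One possible way to turn the three stage formulas into code is sketched below. It is not the site's reference solution. The error handling for a non-positive world_size and for stages outside 1 to 3 is an assumption, since the problem statement does not say what should happen in those cases.

```python
def zero_memory_per_gpu_bytes(n_params: int, world_size: int, stage: int) -> int:
    """Per-GPU bytes for fp16 weights + fp16 grads + fp32 Adam state under ZeRO."""
    if world_size <= 0:
        # Assumption: the problem does not specify behavior for this case.
        raise ValueError("world_size must be positive")

    weights, grads, optim = 2, 2, 12  # bytes per parameter

    if stage == 1:    # shard optimizer state only
        replicated, sharded = weights + grads, optim
    elif stage == 2:  # shard optimizer state + gradients
        replicated, sharded = weights, grads + optim
    elif stage == 3:  # shard everything
        replicated, sharded = 0, weights + grads + optim
    else:
        # Assumption: only stages 1-3 are defined by the problem.
        raise ValueError(f"unsupported ZeRO stage: {stage}")

    # Floor division is applied to the sharded term only; the replicated
    # term is already an integer, so this equals flooring the whole sum.
    return n_params * replicated + (n_params * sharded) // world_size


print(zero_memory_per_gpu_bytes(1_000, 8, 1))  # 5500
print(zero_memory_per_gpu_bytes(1_000, 8, 2))  # 3750
print(zero_memory_per_gpu_bytes(1_000, 8, 3))  # 2000
```

Splitting each stage into a replicated and a sharded byte count keeps the arithmetic in Python integers throughout, which avoids floating-point rounding for very large parameter counts.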

Math

$$M_1 = 4N + \frac{12N}{W}, \quad M_2 = 2N + \frac{14N}{W}, \quad M_3 = \frac{16N}{W}$$
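
To make the three expressions concrete, here is a worked example with W standing for world_size. The values N = 7.5 × 10^9 parameters and W = 64 GPUs are illustrative choices, not inputs given by the problem.

```latex
\begin{align*}
M_1 &= 4(7.5 \times 10^{9}) + \frac{12(7.5 \times 10^{9})}{64}
     = 3.0 \times 10^{10} + 1.40625 \times 10^{9}
     = 3.140625 \times 10^{10} \text{ bytes} \\
M_2 &= 2(7.5 \times 10^{9}) + \frac{14(7.5 \times 10^{9})}{64}
     = 1.5 \times 10^{10} + 1.640625 \times 10^{9}
     = 1.6640625 \times 10^{10} \text{ bytes} \\
M_3 &= \frac{16(7.5 \times 10^{9})}{64}
     = 1.875 \times 10^{9} \text{ bytes}
\end{align*}
```

For comparison, the unsharded cost at the same N is 16N = 1.2 × 10^11 bytes, so stage 3 reduces the per-GPU footprint by exactly a factor of W = 64.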

NumPy

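import numpy as np

def zero_memory_per_gpu_bytes(n_params: int, world_size: int, stage: int) -> int:
    # Return per-GPU bytes for parameters + grads + optimizer state.
    pass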
