Depending on the stage, DeepSpeed ZeRO shards optimizer state, gradients, and parameters across data-parallel GPUs. Compute the per-GPU memory, in bytes, taken by parameters, gradients, and optimizer state under a given ZeRO stage.
Signature: def zero_memory_per_gpu_bytes(n_params: int, world_size: int, stage: int) -> int
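Use mixed-precision conventions (N = n_params):
fp16 parameters: 2 bytes/param
fp16 gradients: 2 bytes/param
fp32 optimizer state (master weights + Adam m + Adam v): 4 + 4 + 4 = 12 bytes/param
Stage 1 (shard optimizer state only):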
bytes = N * (2 + 2) + N * 12 / world_size
Stage 2 (shard optimizer state + gradients):
bytes = N * 2 + N * (2 + 12) / world_size
Stage 3 (shard everything):
bytes = N * (2 + 2 + 12) / world_size = N * 16 / world_size
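As a worked check of the three formulas, take an illustrative configuration (chosen for this example only) of N = 7.5 × 10^9 parameters on world_size = 64 GPUs, with 1 GB = 10^9 bytes. Every sharded term divides evenly here, so integer division does not change the result.

```latex
\begin{aligned}
\text{Stage 1:}\quad & 4N + \frac{12N}{64} = 3.0\times 10^{10} + 1.40625\times 10^{9} = 31{,}406{,}250{,}000 \text{ bytes} \approx 31.4\ \text{GB} \\
\text{Stage 2:}\quad & 2N + \frac{14N}{64} = 1.5\times 10^{10} + 1.640625\times 10^{9} = 16{,}640{,}625{,}000 \text{ bytes} \approx 16.6\ \text{GB} \\
\text{Stage 3:}\quad & \frac{16N}{64} = 1.875\times 10^{9} = 1{,}875{,}000{,}000 \text{ bytes} \approx 1.9\ \text{GB}
\end{aligned}
```

For comparison, the unsharded footprint is 16N = 1.2 × 10^11 bytes (120 GB), which is what every stage reduces to when world_size = 1.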
Return an integer (use integer division).
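A minimal Python sketch of the function, assuming the floor division is applied to the whole sharded term, i.e. (N * k) // world_size rather than N * (k // world_size). The input checks and the ValueError for stages other than 1, 2, 3 are assumptions added for illustration; the statement only defines stages 1 to 3.

```python
# Bytes per parameter under the mixed-precision conventions above.
PARAM_BYTES = 2   # fp16 parameters
GRAD_BYTES = 2    # fp16 gradients
OPTIM_BYTES = 12  # fp32 master weights (4) + Adam m (4) + Adam v (4)


def zero_memory_per_gpu_bytes(n_params: int, world_size: int, stage: int) -> int:
    """Per-GPU bytes for params + grads + optimizer state under ZeRO stage 1, 2, or 3."""
    # Assumed validation, not part of the original problem statement.
    if n_params < 0:
        raise ValueError("n_params must be non-negative")
    if world_size < 1:
        raise ValueError("world_size must be >= 1")

    n = n_params
    if stage == 1:
        # Only optimizer state is partitioned; params and grads are replicated.
        return n * (PARAM_BYTES + GRAD_BYTES) + n * OPTIM_BYTES // world_size
    if stage == 2:
        # Optimizer state and gradients are partitioned; params are replicated.
        return n * PARAM_BYTES + n * (GRAD_BYTES + OPTIM_BYTES) // world_size
    if stage == 3:
        # Everything is partitioned: 16 bytes/param spread over all ranks.
        return n * (PARAM_BYTES + GRAD_BYTES + OPTIM_BYTES) // world_size
    raise ValueError(f"unsupported ZeRO stage: {stage}")
```

Because `*` and `//` share precedence and associate left to right, `n * OPTIM_BYTES // world_size` computes `(n * 12) // world_size`, matching the formula for stage 1. With a non-divisible case such as n_params = 10, world_size = 3, stage 3 gives 160 // 3 = 53.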