TorchedUp
ZeRO Memory per GPU

Difficulty: Medium

DeepSpeed ZeRO shards optimizer state, gradients, and parameters across data-parallel GPUs, with each stage sharding more than the one before it. Compute the per-GPU memory, in bytes, needed for parameters, gradients, and optimizer state at a given ZeRO stage.

Signature: def zero_memory_per_gpu_bytes(n_params: int, world_size: int, stage: int) -> int

Use the following mixed-precision conventions, where N is n_params:

  • fp16 weights: 2 bytes/param
  • fp16 grads: 2 bytes/param
  • fp32 optimizer state (master weights + Adam m + Adam v): 4 + 4 + 4 = 12 bytes/param
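
As a sanity check on these conventions, the sketch below adds up the bytes per parameter when nothing is sharded, which is the baseline that every ZeRO stage improves on. The constant names and the helper unsharded_bytes are illustrative only and are not part of the problem.

```python
# Bytes per parameter under the mixed-precision conventions listed above.
FP16_WEIGHT_BYTES = 2
FP16_GRAD_BYTES = 2
OPTIMIZER_BYTES = 4 + 4 + 4  # fp32 master copy + Adam m + Adam v


def unsharded_bytes(n_params: int) -> int:
    """Per-GPU memory when every GPU holds a full replica (no sharding)."""
    return n_params * (FP16_WEIGHT_BYTES + FP16_GRAD_BYTES + OPTIMIZER_BYTES)


print(unsharded_bytes(1))  # 16 bytes for a single parameter
```

Every replica pays the full 16 bytes/param in this baseline; the stages below divide progressively more of that total by world_size.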

Stage 1 (shard optimizer state only):

bytes = N * (2 + 2) + N * 12 / world_size

Stage 2 (shard optimizer state + gradients):

bytes = N * 2 + N * (2 + 12) / world_size

Stage 3 (shard everything):

bytes = N * (2 + 2 + 12) / world_size = N * 16 / world_size
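
To make the three formulas concrete, here is a worked example for N = 10^9 and world_size W = 8. It assumes that "1B params" means exactly 10^9 parameters, which the problem does not state explicitly.

```latex
\begin{align*}
\text{Stage 1:}\quad & 4N + \frac{12N}{W} = 4\times10^{9} + \frac{12\times10^{9}}{8} = 5.5\times10^{9}\ \text{bytes} \\
\text{Stage 2:}\quad & 2N + \frac{14N}{W} = 2\times10^{9} + \frac{14\times10^{9}}{8} = 3.75\times10^{9}\ \text{bytes} \\
\text{Stage 3:}\quad & \frac{16N}{W} = \frac{16\times10^{9}}{8} = 2\times10^{9}\ \text{bytes}
\end{align*}
```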

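Return an integer: evaluate each division by world_size with integer (floor) division, multiplying before dividing as the formulas are written.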
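
One possible implementation, sketched directly from the three formulas above. Raising ValueError for stages other than 1, 2, or 3 is an assumption; the problem does not say how other values should be handled.

```python
def zero_memory_per_gpu_bytes(n_params: int, world_size: int, stage: int) -> int:
    """Per-GPU bytes for fp16 weights, fp16 grads, and fp32 optimizer state."""
    weights, grads, optim = 2, 2, 12  # bytes per parameter

    if stage == 1:
        # Only the optimizer state is sharded.
        return n_params * (weights + grads) + n_params * optim // world_size
    if stage == 2:
        # Optimizer state and gradients are sharded.
        return n_params * weights + n_params * (grads + optim) // world_size
    if stage == 3:
        # Everything is sharded.
        return n_params * (weights + grads + optim) // world_size

    # Assumption: stages outside 1..3 are rejected rather than guessed at.
    raise ValueError(f"unsupported ZeRO stage: {stage}")
```

Because * and // share the same precedence and associate left to right in Python, each product is formed before the floor division. For example, zero_memory_per_gpu_bytes(7 * 10**9, 64, 3) returns 1_750_000_000, assuming "7B" means 7 * 10**9 parameters.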


Test cases:

  • stage 1, 1B params, world_size = 8
  • stage 2, 1B params, world_size = 8
  • stage 3, 1B params, world_size = 8
  • stage 3, 7B params, world_size = 64