Implement ZeRO Stage 2 gradient averaging: each rank is responsible for averaging only its assigned gradient shard (reduce-scatter), then shares the result (all-gather).
Signature: def zero_stage2(gradients, world_size)
gradients: list of world_size arrays, each (grad_size,) — local gradients on each rank
world_size: number of ranks
Returns: world_size arrays — globally averaged gradients (same result on all ranks)
Stage A — Reduce-Scatter (each rank averages its gradient shard):
shard_size = grad_size // world_size
For rank i:
start = i * shard_size
end = grad_size if i == world_size - 1 else (i+1) * shard_size  # last rank gets the remainder
shard_avg[i] = mean(gradients[0][start:end], ..., gradients[world_size-1][start:end])
Stage B — All-Gather (reconstruct full averaged gradient):
full_grad = concatenate(shard_avg[0], shard_avg[1], ..., shard_avg[world_size-1])
# Return full_grad for all ranks
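A minimal single-process sketch of the two stages, simulating each rank in a loop rather than using a real communication backend; the use of NumPy and the loop-based simulation are assumptions beyond the given signature:

```python
import numpy as np

def zero_stage2(gradients, world_size):
    """Globally average gradients via reduce-scatter + all-gather.

    gradients: list of world_size 1-D arrays (one per rank).
    Returns world_size identical, globally averaged arrays.
    """
    grad_size = len(gradients[0])
    shard_size = grad_size // world_size

    # Stage A -- reduce-scatter: rank i averages only its own shard.
    shard_avg = []
    for i in range(world_size):
        start = i * shard_size
        # Last rank absorbs the remainder when grad_size % world_size != 0.
        end = grad_size if i == world_size - 1 else (i + 1) * shard_size
        shard_avg.append(np.mean([g[start:end] for g in gradients], axis=0))

    # Stage B -- all-gather: every rank concatenates all averaged shards.
    full_grad = np.concatenate(shard_avg)
    return [full_grad.copy() for _ in range(world_size)]
```

A quick check with four ranks, where rank r holds a gradient filled with the value r:

```python
grads = [np.full(10, float(r)) for r in range(4)]
out = zero_stage2(grads, world_size=4)
assert all(np.allclose(g, 1.5) for g in out)  # mean of 0, 1, 2, 3
```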
| Stage | What's sharded   | Memory reduction |
|-------|------------------|------------------|
| 1     | Optimizer states | ~4×              |
| 2     | + Gradients      | ~8×              |
| 3     | + Parameters     | ~64×             |
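The reduction factors follow the mixed-precision Adam accounting from the ZeRO paper: roughly 2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and 12 bytes of fp32 optimizer state per model parameter. A rough sketch of that arithmetic (the byte counts and helper name are illustrative assumptions, not part of the problem):

```python
def per_gpu_bytes_per_param(stage, n_gpus, k=12):
    # Baseline (no ZeRO): 2 B fp16 params + 2 B fp16 grads + k B optimizer state.
    p, g = 2, 2
    if stage == 1:                    # shard optimizer states
        return p + g + k / n_gpus
    if stage == 2:                    # ...plus gradients
        return p + (g + k) / n_gpus
    if stage == 3:                    # ...plus parameters
        return (p + g + k) / n_gpus
    return p + g + k                  # stage 0: nothing sharded

# At 64 GPUs vs the 16 B baseline: stage 1 -> ~4.2 B (~4x),
# stage 2 -> ~2.2 B (~8x), stage 3 -> 0.25 B (64x).
```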
ZeRO Stage 2 is a common configuration for multi-GPU training in DeepSpeed (it must be enabled explicitly; DeepSpeed defaults to Stage 0, i.e. ZeRO disabled).