ZeRO Stage 2 — Gradient Sharding (Hard)

Implement ZeRO Stage 2 gradient averaging: each rank averages only its assigned gradient shard (reduce-scatter), then every rank gathers the averaged shards to reconstruct the full gradient (all-gather).

Signature: def zero_stage2(gradients, world_size)

  • gradients: list of world_size arrays, each (grad_size,) — local gradients on each rank
  • world_size: number of ranks
  • Returns: list of world_size arrays — globally averaged gradients (same result on all ranks)

Algorithm: Reduce-Scatter + All-Gather

Stage A — Reduce-Scatter (each rank averages its gradient shard):

shard_size = grad_size // world_size
For rank i:
    start = i * shard_size
    end   = grad_size if i == world_size - 1 else (i+1) * shard_size  # last rank takes any remainder
    shard_avg[i] = mean(gradients[0][start:end], ..., gradients[world_size-1][start:end])
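
A small self-contained NumPy illustration of Stage A (single-process simulation; the concrete values are taken from the first test case below, and variable names mirror the pseudocode):

import numpy as np

# Illustrative setup: two ranks, grad_size = 4.
gradients = [np.array([1., 2., 3., 4.]), np.array([5., 6., 7., 8.])]
world_size = 2
grad_size = gradients[0].shape[0]
shard_size = grad_size // world_size

stacked = np.stack(gradients)        # shape (world_size, grad_size)
shard_avg = []
for i in range(world_size):
    start = i * shard_size
    end = grad_size if i == world_size - 1 else (i + 1) * shard_size  # last rank takes any remainder
    # Rank i averages only its own columns across all ranks.
    shard_avg.append(stacked[:, start:end].mean(axis=0))

print(shard_avg)   # [array([3., 4.]), array([5., 6.])]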

Stage B — All-Gather (reconstruct full averaged gradient):

full_grad = concatenate(shard_avg[0], shard_avg[1], ..., shard_avg[world_size-1])
# Return full_grad for all ranks
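
Combining both stages, a minimal single-process NumPy sketch of the required function (the collectives are simulated by slicing a stacked matrix; the last-rank remainder handling is the assumption noted in Stage A):

import numpy as np

def zero_stage2(gradients, world_size):
    grad_size = gradients[0].shape[0]
    shard_size = grad_size // world_size
    stacked = np.stack(gradients)          # (world_size, grad_size)

    # Stage A (reduce-scatter): rank i averages only its assigned shard.
    shard_avg = []
    for i in range(world_size):
        start = i * shard_size
        end = grad_size if i == world_size - 1 else (i + 1) * shard_size
        shard_avg.append(stacked[:, start:end].mean(axis=0))

    # Stage B (all-gather): every rank reconstructs the full averaged gradient.
    full_grad = np.concatenate(shard_avg)
    return [full_grad.copy() for _ in range(world_size)]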

ZeRO Stages

| Stage | What's sharded   | Memory reduction |
|-------|------------------|------------------|
| 1     | Optimizer states | ~4×              |
| 2     | + Gradients      | ~8×              |
| 3     | + Parameters     | ~64×             |

ZeRO Stage 2 is a commonly used setting in DeepSpeed for multi-GPU training.
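
For context, selecting Stage 2 in DeepSpeed comes down to the zero_optimization.stage setting; a minimal config fragment, where every key other than "stage" is an illustrative assumption:

# Minimal DeepSpeed-style ZeRO config expressed as a Python dict.
# Only "stage": 2 is the point here; the other keys/values are illustrative.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "zero_optimization": {
        "stage": 2,   # shard optimizer states and gradients across ranks
    },
}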

Language: Python (numpy)

Test Results

○2 ranks, grad=[1,2,3,4] and [5,6,7,8] → averaged [3,4,5,6] on all ranks
○3 ranks with sparse gradients → averaged
○1 rank (no-op): returns input unchanged
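
The first test case can be reproduced directly against the zero_stage2 sketch above:

import numpy as np

grads = [np.array([1., 2., 3., 4.]), np.array([5., 6., 7., 8.])]
out = zero_stage2(grads, world_size=2)
print(out[0])   # [3. 4. 5. 6.] (identical on every rank)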