Implement ZeRO Stage 2 gradient averaging: each rank is responsible for averaging only its assigned gradient shard (reduce-scatter), then shares the result (all-gather).
Signature: def zero_stage2(gradients, world_size)
Parameters:
- gradients: list of world_size arrays, each of shape (grad_size,); the local gradients on each rank
- world_size: number of ranks

Returns: world_size arrays (one per rank) holding the globally averaged gradient, identical on all ranks.

Stage A (Reduce-Scatter): split the gradient vector into world_size contiguous shards along its only axis, with shard_size = grad_size // world_size. Rank i is the only rank that averages the i-th shard. It receives world_size slices of that shard, one from each rank's local gradient, and computes their element-wise mean. If grad_size is not divisible by world_size, the last shard also takes the remainder, so it spans indices (world_size - 1) * shard_size up to grad_size.
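The shard boundaries are the only subtle part of Stage A. Below is a minimal single-process sketch, assuming every rank's local gradient is available as a NumPy array in one list. `shard_bounds` and `reduce_scatter_mean` are hypothetical helper names, not part of the required signature, and no real communication takes place.

```python
import numpy as np

def shard_bounds(grad_size, world_size):
    """Return (start, end) index pairs for each rank's contiguous shard.

    Every shard has grad_size // world_size elements except the last,
    which also absorbs any remainder.
    """
    shard_size = grad_size // world_size
    bounds = []
    for rank in range(world_size):
        start = rank * shard_size
        end = grad_size if rank == world_size - 1 else start + shard_size
        bounds.append((start, end))
    return bounds

def reduce_scatter_mean(gradients, world_size):
    """Simulated reduce-scatter: entry i is rank i's averaged shard."""
    grad_size = gradients[0].shape[0]
    averaged_shards = []
    for start, end in shard_bounds(grad_size, world_size):
        # Rank i sees world_size slices (one per rank) and averages them.
        slices = np.stack([g[start:end] for g in gradients])
        averaged_shards.append(slices.mean(axis=0))
    return averaged_shards

print(shard_bounds(10, 3))  # [(0, 3), (3, 6), (6, 10)]
```

With grad_size = 10 and world_size = 3, shard_size is 3, so ranks 0 and 1 each own three elements and rank 2 owns the last four.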
Stage B (All-Gather): every rank publishes its averaged shard; concatenating the per-rank shards in rank order reconstructs the full globally averaged gradient. Return the same full vector for every rank.
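Putting both stages together, one possible sketch of `zero_stage2` follows. It simulates the collectives in a single process (the Stage A loop stands in for reduce-scatter, `np.concatenate` for all-gather) and returns an independent copy per rank so that mutating one rank's result cannot affect another's. Treat it as an illustration, not a reference solution.

```python
import numpy as np

def zero_stage2(gradients, world_size):
    """Simulate ZeRO Stage 2 gradient averaging on one process.

    Stage A (reduce-scatter): rank i averages only shard i.
    Stage B (all-gather): the averaged shards are concatenated in rank
    order, and every rank receives the same full vector.
    """
    grad_size = gradients[0].shape[0]
    shard_size = grad_size // world_size

    # Stage A: each rank averages its own contiguous shard.
    averaged_shards = []
    for rank in range(world_size):
        start = rank * shard_size
        end = grad_size if rank == world_size - 1 else start + shard_size
        shard_mean = np.mean([g[start:end] for g in gradients], axis=0)
        averaged_shards.append(shard_mean)

    # Stage B: all-gather into the full globally averaged gradient.
    full = np.concatenate(averaged_shards)
    return [full.copy() for _ in range(world_size)]

rng = np.random.default_rng(0)
grads = [rng.standard_normal(10) for _ in range(4)]
out = zero_stage2(grads, world_size=4)
reference = np.mean(grads, axis=0)
print(all(np.allclose(o, reference) for o in out))  # True
```

Because each shard's mean is taken element-wise over the same set of ranks, the concatenated result equals a direct average of the full vectors; sharding changes which rank does the work, not the answer.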
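| Stage | What's sharded | Memory reduction |
|-------|----------------|------------------|
| 1 | Optimizer states | ~4× |
| 2 | + Gradients | ~8× |
| 3 | + Parameters | Linear in the number of GPUs N_d (64× at N_d = 64) |

Reductions refer to per-GPU model-state memory (parameters, gradients, optimizer states) under mixed-precision Adam, relative to plain data parallelism; the Stage 1 and Stage 2 figures are the limits for large N_d.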
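The table's ratios follow from the per-GPU memory accounting in the ZeRO paper. The sketch below assumes mixed-precision Adam: 2 bytes per parameter each for fp16 parameters and gradients, plus K = 12 bytes per parameter of fp32 optimizer state. `model_state_bytes` is a hypothetical helper, and activations and temporary buffers are ignored.

```python
def model_state_bytes(num_params, stage, n_gpus, k=12):
    """Per-GPU model-state memory in bytes for ZeRO stage 0-3.

    Assumes fp16 params (2 B/param), fp16 grads (2 B/param) and
    k bytes/param of fp32 optimizer state (k = 12 for mixed-precision Adam).
    """
    psi = num_params
    if stage == 0:  # plain data parallelism: everything replicated
        return (2 + 2 + k) * psi
    if stage == 1:  # optimizer states sharded
        return 2 * psi + 2 * psi + k * psi / n_gpus
    if stage == 2:  # + gradients sharded
        return 2 * psi + (2 + k) * psi / n_gpus
    if stage == 3:  # + parameters sharded
        return (2 + 2 + k) * psi / n_gpus
    raise ValueError(f"unknown ZeRO stage: {stage}")

psi, n_gpus = 7.5e9, 64
baseline = model_state_bytes(psi, 0, n_gpus)
for stage in range(4):
    mem = model_state_bytes(psi, stage, n_gpus)
    print(f"stage {stage}: {mem / 1e9:6.1f} GB  ({baseline / mem:4.1f}x)")
```

For a 7.5B-parameter model on 64 GPUs this gives about 120, 31.4, 16.6 and 1.9 GB per GPU for stages 0 through 3, so at this scale Stage 2 saves roughly 7.2× rather than the asymptotic 8×.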
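ZeRO Stage 2 is a common choice in DeepSpeed for multi-GPU training. It is enabled by setting `"stage": 2` under `zero_optimization` in the DeepSpeed config; it is not on by default (the default stage is 0, meaning ZeRO is disabled).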
```python
import numpy as np

def zero_stage2(gradients, world_size):
    # Stage A: reduce-scatter; Stage B: all-gather.
    pass
```