TorchedUp

97. ZeRO Stage 2 — Gradient Sharding

Hard

Implement ZeRO Stage 2 gradient averaging: each rank is responsible for averaging only its assigned gradient shard (reduce-scatter), then shares the result (all-gather).

Signature: def zero_stage2(gradients, world_size)

  • gradients: list of world_size arrays, each of shape (grad_size,); entry i is the local gradient on rank i
  • world_size: number of ranks
  • Returns: list of world_size arrays, each of shape (grad_size,), holding the globally averaged gradient (identical on every rank)
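Before sharding anything, it helps to pin down what the returned vectors must equal. The sketch below is a hypothetical reference helper (`naive_allreduce_mean` is not part of the problem's API) that computes the global mean directly. Assuming 1-D NumPy gradients, a correct `zero_stage2` should produce the same values on every rank.

```python
import numpy as np

def naive_allreduce_mean(gradients, world_size):
    """Hypothetical reference: every rank ends up with the element-wise
    mean of all ranks' local gradients (what a plain all-reduce gives)."""
    assert len(gradients) == world_size
    mean = np.mean(np.stack(gradients), axis=0)
    # Give each rank its own copy so no two ranks alias one buffer.
    return [mean.copy() for _ in range(world_size)]

grads = [np.array([1.0, 2.0, 3.0]), np.array([3.0, 4.0, 5.0])]
out = naive_allreduce_mean(grads, world_size=2)
```

Both `out[0]` and `out[1]` hold the values 2, 3 and 4, the per-element averages of the two ranks' gradients.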

Algorithm: Reduce-Scatter + All-Gather

Stage A — Reduce-Scatter: split the gradient vector into world_size contiguous shards along its only axis, each of shard_size = grad_size // world_size elements; when grad_size is not divisible by world_size, the last shard also takes the remainder. Rank i is the only rank that averages the i-th shard: it collects the matching slice from every rank's local gradient (world_size slices in total) and computes their element-wise mean.
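Stage A can be simulated in a single process. This is a minimal sketch, assuming each rank's gradient is a 1-D NumPy array; `shard_bounds` and `reduce_scatter_mean` are hypothetical helper names, and a real system would run a collective reduce-scatter across devices rather than a Python loop.

```python
import numpy as np

def shard_bounds(grad_size, world_size):
    """Contiguous [start, end) ranges; the last shard absorbs the remainder."""
    shard_size = grad_size // world_size
    bounds = []
    for i in range(world_size):
        start = i * shard_size
        end = grad_size if i == world_size - 1 else start + shard_size
        bounds.append((start, end))
    return bounds

def reduce_scatter_mean(gradients, world_size):
    """Simulated reduce-scatter: entry i is what rank i computes, namely
    the mean of shard i over all ranks' local gradients."""
    grad_size = gradients[0].shape[0]
    shards = []
    for start, end in shard_bounds(grad_size, world_size):
        # Rank i sees world_size slices covering [start, end) ...
        slices = [g[start:end] for g in gradients]
        # ... and is the only rank that averages them.
        shards.append(np.mean(slices, axis=0))
    return shards

grads = [np.arange(10.0), np.arange(10.0) + 2]
shards = reduce_scatter_mean(grads, 4)
```

With grad_size = 10 and world_size = 4, shard_size is 2, so `shard_bounds(10, 4)` gives (0, 2), (2, 4), (4, 6) and (6, 10): the last rank owns four elements.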

Stage B — All-Gather: every rank publishes its averaged shard; concatenating the per-rank shards in rank order reconstructs the full globally-averaged gradient. Return the same full vector for every rank.
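Putting both stages together, here is one possible single-process simulation of the function described above. It is a sketch, not the site's reference solution: it assumes 1-D NumPy inputs and models the all-gather as a concatenation in rank order followed by one copy per rank.

```python
import numpy as np

def zero_stage2(gradients, world_size):
    """Single-process simulation of ZeRO-2 gradient averaging:
    reduce-scatter (each rank averages one shard), then all-gather."""
    grad_size = gradients[0].shape[0]
    shard_size = grad_size // world_size

    # Stage A: rank r owns the contiguous shard [start, end); the last
    # rank's shard runs to the end of the vector to absorb any remainder.
    owned = []
    for rank in range(world_size):
        start = rank * shard_size
        end = grad_size if rank == world_size - 1 else start + shard_size
        owned.append(np.mean([g[start:end] for g in gradients], axis=0))

    # Stage B: all-gather. Concatenating shards in rank order rebuilds the
    # full averaged vector, and every rank receives an identical copy.
    full = np.concatenate(owned)
    return [full.copy() for _ in range(world_size)]

grads = [np.arange(7.0) * (r + 1) for r in range(3)]
result = zero_stage2(grads, world_size=3)
```

Here grad_size = 7 and world_size = 3, so shard_size is 2: ranks 0 and 1 each average two elements and rank 2 averages the last three. Every entry of `result` equals the plain mean over ranks, `np.arange(7.0) * 2`.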

ZeRO Stages

| Stage | What's sharded | Memory reduction vs. data parallel |
|-------|----------------|------------------------------------|
| 1 | Optimizer states | ~4× |
| 2 | + Gradients | ~8× |
| 3 | + Parameters | Grows linearly with GPU count (~64× on 64 GPUs) |
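The table's ratios follow from the ZeRO paper's accounting for mixed-precision Adam. Per parameter, each GPU stores 2 bytes of fp16 weights, 2 bytes of fp16 gradients and K = 12 bytes of fp32 optimizer state (master weights, momentum, variance). Each stage divides one more of these terms by the number of GPUs N_d. The sketch below (the function name `zero_memory_gb` is hypothetical) evaluates the per-GPU formulas for Ψ parameters.

```python
def zero_memory_gb(num_params, num_gpus, k=12):
    """Per-GPU model-state memory in GB (1 GB = 1e9 bytes) under
    mixed-precision Adam: 2 B/param fp16 weights, 2 B/param fp16 grads,
    k B/param optimizer state (k = 12 in the ZeRO paper)."""
    p, n = num_params, num_gpus
    return {
        "baseline": (2 + 2 + k) * p / 1e9,             # plain data parallel
        "stage1": (2 * p + 2 * p + k * p / n) / 1e9,   # shard optimizer states
        "stage2": (2 * p + (2 + k) * p / n) / 1e9,     # + shard gradients
        "stage3": (2 + 2 + k) * p / n / 1e9,           # + shard parameters
    }

mem = zero_memory_gb(7.5e9, 64)
for stage, gb in mem.items():
    print(f"{stage}: {gb:.1f} GB ({mem['baseline'] / gb:.1f}x)")
```

For a 7.5B-parameter model on 64 GPUs this gives roughly 120, 31.4, 16.6 and 1.9 GB per GPU. As N_d grows, the stage 1 and stage 2 ratios approach 16/4 = 4× and 16/2 = 8×, while stage 3 keeps scaling with N_d.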

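ZeRO Stage 2 is a widely used DeepSpeed setting for multi-GPU training, but it is not enabled by default: DeepSpeed's `zero_optimization.stage` option defaults to 0 (ZeRO disabled), so stage 2 must be requested explicitly in the config.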
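Turning stage 2 on is a config switch rather than code. This is a minimal sketch, assuming a recent DeepSpeed release. The keys below appear in DeepSpeed's configuration documentation, but the values are illustrative. Such a dict is typically passed to `deepspeed.initialize` through its `config` argument.

```python
# Illustrative DeepSpeed config, written as a Python dict, that enables
# ZeRO stage 2. The micro-batch size is an arbitrary placeholder value.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "zero_optimization": {
        "stage": 2,                    # shard optimizer states + gradients
        "reduce_scatter": True,        # average gradients via reduce-scatter
        "overlap_comm": True,          # overlap grad reduction with backward
        "contiguous_gradients": True,  # copy grads into a contiguous buffer
    },
}
```

The `reduce_scatter` flag corresponds to Stage A of the algorithm above: gradients are averaged shard by shard rather than with a full all-reduce.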


NumPy

import numpy as np


def zero_stage2(gradients, world_size):
    # Return a list of world_size arrays, each the globally averaged gradient.
    pass
