
ZeRO Stage 1: Optimizer State Sharding (Medium)

ZeRO Stage 1 (ZeRO = Zero Redundancy Optimizer) shards optimizer states across workers instead of replicating them. Each worker stores and updates only 1/N of the optimizer states, cutting per-worker optimizer memory by N×. For example, fp32 Adam keeps two state values (m and v) per parameter, so a D-parameter model costs 8D bytes of optimizer state per worker when replicated, but only 8D/N bytes when sharded across N workers.

Simulate one step: given a gradient that is identical on all workers (shared via all-reduce), each worker applies an Adam-style update, using its optimizer state shard (m, v), only to the parameter shard it owns. All workers then all-gather the full updated parameters.
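
A standard bias-corrected Adam update is assumed here (matching the beta1, beta2, eps, and t arguments in the signature below); for each element of the owned shard with gradient g:

  m_t = beta1 * m_{t-1} + (1 - beta1) * g
  v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
  m_hat = m_t / (1 - beta1^t)
  v_hat = v_t / (1 - beta2^t)
  param = param - lr * m_hat / (sqrt(v_hat) + eps)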

Signature (a numpy sketch of this function follows the parameter list): def zero_stage1_step(full_grad, param_shards, m_shards, v_shards, worker_rank, num_workers, lr, beta1=0.9, beta2=0.999, eps=1e-8, t=1)

  • full_grad: (D,) — gradient after all-reduce (same on all workers)
  • param_shards: (D,) — this worker's full copy of the parameters (every worker holds the same full copy)
  • m_shards: (D//num_workers,) — this worker's Adam m shard
  • v_shards: (D//num_workers,) — this worker's Adam v shard
  • worker_rank: int — which shard this worker owns
  • num_workers: int
  • Returns: (updated_params, new_m, new_v) where:
    • updated_params: (D,) — full params with this worker's shard updated (other shards unchanged)
    • new_m, new_v: (D//num_workers,) — updated optimizer states for this shard
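
A minimal numpy sketch, assuming D is divisible by num_workers and that shard k owns the contiguous slice [k*D//num_workers, (k+1)*D//num_workers) — a common layout, but an assumption here:

  import numpy as np

  def zero_stage1_step(full_grad, param_shards, m_shards, v_shards,
                       worker_rank, num_workers, lr,
                       beta1=0.9, beta2=0.999, eps=1e-8, t=1):
      # Boundaries of the shard this worker owns (assumed contiguous layout).
      D = full_grad.shape[0]
      shard_size = D // num_workers
      lo = worker_rank * shard_size
      hi = lo + shard_size

      g = full_grad[lo:hi]  # slice of the already all-reduced gradient

      # Adam first/second moment updates, shard-local state only.
      new_m = beta1 * m_shards + (1 - beta1) * g
      new_v = beta2 * v_shards + (1 - beta2) * g ** 2

      # Bias correction.
      m_hat = new_m / (1 - beta1 ** t)
      v_hat = new_v / (1 - beta2 ** t)

      # Update only this worker's slice of the full parameter copy;
      # the other slices stay stale until the all-gather refreshes them.
      updated_params = param_shards.copy()
      updated_params[lo:hi] = updated_params[lo:hi] - lr * m_hat / (np.sqrt(v_hat) + eps)

      return updated_params, new_m, new_v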

Language: Python (numpy)

Test Results

  • Worker 0 of 2, D=8, seed 42
  • Worker 1 of 2, D=8, seed 42
  • Worker 2 of 4, D=16, seed 7 (Premium)
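
To illustrate the final all-gather, here is a hypothetical harness in the spirit of the first two test cases (2 workers, D=8, seed 42). The grader's actual data generation and learning rate are unknown; this only shows how the updated shards would be stitched back into full parameters:

  import numpy as np

  rng = np.random.default_rng(42)    # mirrors the "seed 42" test cases (data layout is a guess)
  D, num_workers, lr = 8, 2, 1e-3    # lr is a placeholder value
  params = rng.standard_normal(D)
  grad = rng.standard_normal(D)      # stand-in for the post-all-reduce gradient
  shard = D // num_workers

  shards = []
  for rank in range(num_workers):    # run worker 0 and worker 1 in sequence
      m = np.zeros(shard)
      v = np.zeros(shard)
      updated, m, v = zero_stage1_step(grad, params.copy(), m, v,
                                       rank, num_workers, lr)
      shards.append(updated[rank * shard:(rank + 1) * shard])

  full_params = np.concatenate(shards)  # simulated all-gather of the updated shards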