ZeRO Stage 1 (Zero Redundancy Optimizer) shards optimizer states across workers instead of replicating them. Each worker stores and updates only 1/N of the optimizer states, cutting per-worker optimizer memory by a factor of N.
Simulate: given a gradient (shared across workers via all-reduce), each worker updates only its shard of parameters using Adam-style optimizer state (m, v). Then all workers all-gather the full updated params.
Signature: def zero_stage1_step(full_grad, param_shards, m_shards, v_shards, worker_rank, num_workers, lr, beta1=0.9, beta2=0.999, eps=1e-8, t=1)
full_grad: (D,) — gradient after all-reduce (same on all workers)
param_shards: (D,) — this worker's full param copy (all workers have the same)
m_shards: (D//num_workers,) — this worker's Adam m shard
v_shards: (D//num_workers,) — this worker's Adam v shard
worker_rank: int — which shard this worker owns
num_workers: int — total number of workers

Returns (updated_params, new_m, new_v) where:
updated_params: (D,) — full params with this worker's shard updated (other shards unchanged)
new_m, new_v: (D//num_workers,) — updated optimizer states for this shard

Math
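Presumably the update is standard Adam with bias correction (consistent with the beta1, beta2, eps, and t arguments in the signature), applied only to this worker's slice g = full_grad[rank*S : (rank+1)*S] with S = D // num_workers:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t$$
$$v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$
$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$
$$\theta_t = \theta_{t-1} - \mathrm{lr} \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$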
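A minimal NumPy sketch of zero_stage1_step under these assumptions (names match the signature above; the all-gather is simulated by writing the updated shard back into this worker's full parameter copy, and D is assumed divisible by num_workers):

```python
import numpy as np

def zero_stage1_step(full_grad, param_shards, m_shards, v_shards,
                     worker_rank, num_workers, lr,
                     beta1=0.9, beta2=0.999, eps=1e-8, t=1):
    D = full_grad.shape[0]
    shard_size = D // num_workers          # assumes D % num_workers == 0
    start = worker_rank * shard_size
    end = start + shard_size

    # Each worker uses only its slice of the (replicated) gradient.
    g = full_grad[start:end]

    # Standard Adam moment updates on this worker's shard only.
    new_m = beta1 * m_shards + (1 - beta1) * g
    new_v = beta2 * v_shards + (1 - beta2) * g ** 2

    # Bias correction.
    m_hat = new_m / (1 - beta1 ** t)
    v_hat = new_v / (1 - beta2 ** t)

    # Update only this worker's parameter shard; the other shards stay
    # unchanged (in a real system, all-gather would fill them in from peers).
    updated_params = param_shards.copy()
    updated_params[start:end] -= lr * m_hat / (np.sqrt(v_hat) + eps)

    return updated_params, new_m, new_v
```

For example, with D=8 and num_workers=2, rank 0 updates indices 0..3 and rank 1 updates indices 4..7; concatenating (all-gathering) the two updated shards reconstructs the fully updated parameter vector.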