TorchedUp

78. ZeRO Stage 1: Optimizer State Sharding

Medium

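ZeRO Stage 1 (Zero Redundancy Optimizer) shards optimizer states across workers instead of replicating them. Each worker stores and updates only 1/N of the optimizer states (N = number of workers), so per-worker optimizer-state memory drops by a factor of N. Parameters and gradients are still replicated on every worker at this stage.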
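To put the N× figure in concrete terms, the sketch below counts the bytes of Adam state (m and v) that a single worker holds. It assumes fp32 storage (4 bytes per value) and a hypothetical 1B-parameter model on 8 workers; the helper name `adam_state_bytes_per_worker` is illustrative and not part of any library.

```python
def adam_state_bytes_per_worker(num_params: int, num_workers: int,
                                bytes_per_value: int = 4, sharded: bool = True) -> int:
    """Bytes of Adam state (m and v) held by one worker (assumes fp32 by default)."""
    total = 2 * num_params * bytes_per_value  # one m and one v value per parameter
    return total // num_workers if sharded else total


D, N = 1_000_000_000, 8  # hypothetical model size and worker count
replicated = adam_state_bytes_per_worker(D, N, sharded=False)
sharded = adam_state_bytes_per_worker(D, N, sharded=True)
print(f"replicated: {replicated / 1e9:.1f} GB per worker")
print(f"ZeRO-1:     {sharded / 1e9:.1f} GB per worker")
```

Under these assumptions the replicated layout needs 8 GB of Adam state per worker, while ZeRO-1 needs 1 GB.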

Simulate one step: given a gradient that has already been shared across workers via all-reduce, each worker updates only its shard of the parameters using Adam-style optimizer state (m, v). Afterward, the workers all-gather the updated shards so that every worker again holds the full, updated parameter vector.
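The two collectives can be mimicked in a single process with NumPy. This is a sketch under stated assumptions: all-reduce is taken to average the per-worker gradients (some setups sum instead), shards are contiguous equal-size slices in rank order, and `shard_slice`, `all_reduce_mean`, and `all_gather` are hypothetical helper names.

```python
import numpy as np


def shard_slice(rank: int, dim: int, num_workers: int) -> slice:
    """Contiguous index range owned by `rank` (assumes dim % num_workers == 0)."""
    size = dim // num_workers
    return slice(rank * size, (rank + 1) * size)


def all_reduce_mean(local_grads: list) -> np.ndarray:
    """Stand-in for all-reduce: every worker ends up with the average gradient."""
    return np.mean(np.stack(local_grads), axis=0)


def all_gather(shards: list) -> np.ndarray:
    """Stand-in for all-gather: concatenate each rank's shard in rank order."""
    return np.concatenate(shards)


rng = np.random.default_rng(0)
D, N = 8, 4
local_grads = [rng.standard_normal(D) for _ in range(N)]  # one gradient per worker's micro-batch
g = all_reduce_mean(local_grads)                          # identical on all workers

# Each rank would update only g[shard_slice(rank, D, N)]; gathering the
# per-rank slices back in rank order reconstructs the full vector.
pieces = [g[shard_slice(r, D, N)] for r in range(N)]
assert np.array_equal(all_gather(pieces), g)
```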

Signature: def zero_stage1_step(full_grad, param_shards, m_shards, v_shards, worker_rank, num_workers, lr, beta1=0.9, beta2=0.999, eps=1e-8, t=1)

  • full_grad: (D,) — gradient after all-reduce (same on all workers)
  • param_shards: (D,) — this worker's full copy of the parameters (identical on all workers)
  • m_shards: (D//num_workers,) — this worker's Adam m shard
  • v_shards: (D//num_workers,) — this worker's Adam v shard
  • worker_rank: int — which shard this worker owns
  • num_workers: int
  • Returns: (updated_params, new_m, new_v) where:
    • updated_params: (D,) — full params with this worker's shard updated (other shards unchanged)
    • new_m, new_v: (D//num_workers,) — updated optimizer states for this shard
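A minimal sketch of the per-worker step, not the site's reference solution. It assumes each rank owns the contiguous slice `[rank * S, (rank + 1) * S)` with `S = D // num_workers` (the statement does not fix the layout), that D is divisible by num_workers, and that inputs are float NumPy arrays. The usage part runs every rank, keeps each rank's own slice (its all-gather contribution), and checks the result against one unsharded Adam step.

```python
import numpy as np


def zero_stage1_step(full_grad, param_shards, m_shards, v_shards, worker_rank,
                     num_workers, lr, beta1=0.9, beta2=0.999, eps=1e-8, t=1):
    """Adam update on this worker's shard only (assumes contiguous shards)."""
    size = full_grad.shape[0] // num_workers          # shard length S
    lo, hi = worker_rank * size, (worker_rank + 1) * size

    g = full_grad[lo:hi]                              # gradient entries this rank owns
    new_m = beta1 * m_shards + (1 - beta1) * g        # first-moment update
    new_v = beta2 * v_shards + (1 - beta2) * g ** 2   # second-moment update
    m_hat = new_m / (1 - beta1 ** t)                  # bias correction
    v_hat = new_v / (1 - beta2 ** t)

    updated_params = np.array(param_shards, dtype=float)  # copy; other shards stay as-is
    updated_params[lo:hi] -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return updated_params, new_m, new_v


# Usage: simulate every rank, then all-gather by concatenating owned slices.
rng = np.random.default_rng(0)
D, N, lr = 8, 4, 1e-2
S = D // N
params = rng.standard_normal(D)
grad = rng.standard_normal(D)                         # already all-reduced

owned = []
for rank in range(N):
    p, _, _ = zero_stage1_step(grad, params, np.zeros(S), np.zeros(S), rank, N, lr)
    owned.append(p[rank * S:(rank + 1) * S])          # this rank's all-gather contribution
full = np.concatenate(owned)

# Reference: one unsharded Adam step from zero state at t = 1.
m_ref = (1 - 0.9) * grad
v_ref = (1 - 0.999) * grad ** 2
ref = params - lr * (m_ref / (1 - 0.9)) / (np.sqrt(v_ref / (1 - 0.999)) + 1e-8)
assert np.allclose(full, ref)
```

Copying `param_shards` keeps the function free of side effects, which makes it easy to compare every rank's output against the same starting parameters.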

Math

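$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2$$

$$\hat{m}_t = \frac{m_t}{1-\beta_1^{\,t}}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^{\,t}}, \qquad \theta \leftarrow \theta - \frac{\eta\,\hat{m}_t}{\sqrt{\hat{v}_t}+\epsilon}$$

Here η is `lr`, g_t is this worker's slice of `full_grad`, and the update is applied elementwise to this worker's shard of θ only.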
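One consequence of Adam's bias correction (m̂_t = m_t / (1 - β1^t), v̂_t = v_t / (1 - β2^t)) is easy to check numerically: starting from m_0 = v_0 = 0, the first step gives m̂_1 = g_1 and v̂_1 = g_1², so each coordinate moves by roughly lr·sign(g_1) regardless of the gradient's scale, as long as |g| is much larger than ε. The gradient values below are arbitrary.

```python
import numpy as np

beta1, beta2, eps, lr = 0.9, 0.999, 1e-8, 1e-3
g = np.array([4.0, -0.5, 0.01])   # arbitrary gradients of very different scales

m = (1 - beta1) * g               # m_1 with m_0 = 0
v = (1 - beta2) * g ** 2          # v_1 with v_0 = 0
m_hat = m / (1 - beta1 ** 1)      # bias-corrected: equals g
v_hat = v / (1 - beta2 ** 1)      # bias-corrected: equals g**2
step = lr * m_hat / (np.sqrt(v_hat) + eps)

# Every coordinate moves by about lr in the direction of its gradient sign.
assert np.allclose(step, lr * np.sign(g), rtol=1e-5)
```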


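Starter code (NumPy):

```python
import numpy as np


def zero_stage1_step(full_grad, param_shards, m_shards, v_shards, worker_rank,
                     num_workers, lr, beta1=0.9, beta2=0.999, eps=1e-8, t=1):
    # Return (updated_params, new_m, new_v) as described above.
    pass
```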
