TorchedUp

98. ZeRO Stage 3 — Parameter Sharding

Hard

Implement the ZeRO Stage 3 parameter update: each rank owns one shard of the model parameters and updates only that shard, using the matching averaged gradient shard it received from reduce-scatter.

Signature: def zero_stage3(params_shards, gradients_shards, lr)

  • params_shards: list of arrays, one per rank — each rank's parameter shard
  • gradients_shards: list of arrays, one per rank — averaged gradient shard (from reduce-scatter)
  • lr: learning rate (float)
  • Returns: list of updated parameter shards (same structure as input)
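
The gradients_shards input is described as the output of a reduce-scatter. As a rough single-process sketch of how such shards could be produced, assuming every rank holds a full flat gradient whose length divides evenly by the number of ranks (simulate_reduce_scatter is a hypothetical helper, not part of the problem's API):

```python
import numpy as np

def simulate_reduce_scatter(full_grads_per_rank):
    """Hypothetical helper: average full gradients across ranks, then
    split the result so that rank i keeps only chunk i."""
    world_size = len(full_grads_per_rank)
    # "Reduce" step: element-wise mean over all ranks.
    mean_grad = np.mean(np.stack(full_grads_per_rank), axis=0)
    # "Scatter" step: split into world_size equal chunks, one per rank.
    return np.split(mean_grad, world_size)

# Two simulated ranks, each holding a full gradient of length 4.
full_grads = [np.array([1.0, 2.0, 3.0, 4.0]), np.array([3.0, 4.0, 5.0, 6.0])]
gradients_shards = simulate_reduce_scatter(full_grads)
```

In a real distributed run the collective is performed by the communication backend and no rank ever materializes the averaged full gradient; the sketch only mirrors the resulting data layout.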

Algorithm

Each rank independently updates its own parameter shard with simple SGD:

world_size = len(params_shards)  # one shard per rank
for rank in range(world_size):
    params_shards[rank] = params_shards[rank] - lr * gradients_shards[rank]
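
The update rule above can be turned into a short NumPy function. This is one possible solution sketch, not the site's reference solution; returning new arrays rather than mutating the inputs is an assumption about the expected behavior.

```python
import numpy as np

def zero_stage3(params_shards, gradients_shards, lr):
    # One plain SGD step per rank. Each rank touches only its own shard,
    # so the shards are processed independently of each other.
    return [
        np.asarray(p, dtype=float) - lr * np.asarray(g, dtype=float)
        for p, g in zip(params_shards, gradients_shards)
    ]

# Two simulated ranks, each owning a shard of two parameters.
params = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
grads = [np.array([0.5, 0.5]), np.array([1.0, -1.0])]
updated = zero_stage3(params, grads, lr=0.1)
```

Here updated[0] is approximately [0.95, 1.95] and updated[1] is approximately [2.9, 4.1], while the input arrays are left unchanged.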

ZeRO Stage 3 vs FSDP

ZeRO Stage 3 (DeepSpeed) and PyTorch FSDP (Fully Sharded Data Parallel) implement the same core idea: shard parameters, gradients, and optimizer states across ranks. Full parameters are reconstructed with an all-gather only when a layer needs them for its forward or backward pass, and the gathered copy is freed afterward.
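
To make the all-gather step concrete, here is a minimal single-process sketch. It assumes the shards are flat 1-D NumPy arrays that were split in rank order; simulate_all_gather is a hypothetical helper and does not correspond to any DeepSpeed or FSDP API.

```python
import numpy as np

def simulate_all_gather(params_shards):
    """Hypothetical helper: every rank receives the concatenation of
    all shards, i.e. a temporary copy of the full parameter vector."""
    full = np.concatenate(params_shards)
    # One independent copy per rank, as each rank has its own memory.
    return [full.copy() for _ in params_shards]

shards = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
gathered = simulate_all_gather(shards)
```

After the layer's computation finishes, each rank would drop its gathered copy and keep only its own shard, which is where the memory saving comes from.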

Memory Savings

With world_size=64 GPUs, ZeRO Stage 3 reduces per-GPU memory for parameters, gradients, and optimizer states by about 64×, which makes training models with 175B+ parameters feasible. The saving applies to these model states; activation memory is separate.
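
A back-of-the-envelope calculation shows the scale of the saving. It assumes mixed-precision Adam at roughly 16 bytes of model state per parameter (2 for fp16 parameters, 2 for fp16 gradients, 12 for fp32 optimizer state); that figure and the model_state_gb helper are assumptions for illustration, not part of the problem statement.

```python
def model_state_gb(num_params, world_size, bytes_per_param=16):
    """Per-GPU model-state memory in GB when states are sharded evenly.
    bytes_per_param=16 is an assumed mixed-precision Adam footprint."""
    return num_params * bytes_per_param / world_size / 1e9

unsharded_gb = model_state_gb(175e9, world_size=1)   # about 2800 GB
sharded_gb = model_state_gb(175e9, world_size=64)    # about 43.75 GB
```

Under these assumptions, a 175B-parameter model needs about 2.8 TB of model state in total, or roughly 44 GB per GPU when sharded across 64 ranks.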

NumPy

import numpy as np

def zero_stage3(params_shards, gradients_shards, lr):
    pass
