Implement the ZeRO Stage 3 parameter update: each rank owns a shard of the model parameters and updates only that shard, using the averaged gradient shard it received from the reduce-scatter.
Signature: def zero_stage3(params_shards, gradients_shards, lr)
params_shards: list of arrays, one per rank (each rank's parameter shard)
gradients_shards: list of arrays, one per rank (each rank's averaged gradient shard, from reduce-scatter)
lr: learning rate (float)

Each rank independently updates its own parameter shard with plain SGD:
world_size = len(params_shards)
for rank in range(world_size):
    params_shards[rank] = params_shards[rank] - lr * gradients_shards[rank]
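A minimal NumPy sketch of the update described above. It assumes the function should return the updated shards as a new list (the statement does not say whether `params_shards` must be mutated in place) and that each parameter shard and its gradient shard have the same shape.

```python
import numpy as np


def zero_stage3(params_shards, gradients_shards, lr):
    """Apply one SGD step to each rank's parameter shard.

    Assumption: updated shards are returned as a new list rather than
    written back into params_shards.
    """
    if len(params_shards) != len(gradients_shards):
        raise ValueError("need exactly one gradient shard per parameter shard")
    updated = []
    for p, g in zip(params_shards, gradients_shards):
        # Each rank touches only its own slice of the model.
        updated.append(np.asarray(p, dtype=float) - lr * np.asarray(g, dtype=float))
    return updated


# Usage: two ranks, each owning a 2-element shard.
params = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
grads = [np.array([0.5, 0.5]), np.array([1.0, -1.0])]
new_params = zero_stage3(params, grads, lr=0.1)
# new_params is approximately [array([0.95, 1.95]), array([2.9, 4.1])]
```

Because no rank reads another rank's shard, the loop body is exactly what each GPU would run locally; the list only stands in for the collection of ranks.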
ZeRO Stage 3 (DeepSpeed) and PyTorch FSDP (Fully Sharded Data Parallel) implement the same core idea: shard parameters, gradients, and optimizer states across ranks. Full parameters are reconstructed with an all-gather only when a layer needs them for the forward or backward pass.
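The sketch below simulates one fully sharded step in a single process to show how the pieces fit together. The collectives are imitated with NumPy (no real communication, not the `torch.distributed` API), and the helper names `reduce_scatter_mean`, `all_gather`, and `zero3_step` are hypothetical. The final all-gather stands in for the gather that the next forward pass would perform; the check confirms the result matches ordinary replicated data parallelism.

```python
import numpy as np


def reduce_scatter_mean(per_rank_grads, world_size):
    """Simulated reduce-scatter: average full gradients across ranks,
    then hand rank r only the r-th shard of the result."""
    mean_grad = np.mean(per_rank_grads, axis=0)
    return np.array_split(mean_grad, world_size)


def all_gather(shards):
    """Simulated all-gather: every rank receives the concatenation of all shards."""
    return np.concatenate(shards)


def zero3_step(param_shards, per_rank_grads, lr):
    world_size = len(param_shards)
    grad_shards = reduce_scatter_mean(per_rank_grads, world_size)
    # Each rank updates only the parameter shard it owns.
    return [p - lr * g for p, g in zip(param_shards, grad_shards)]


rng = np.random.default_rng(0)
world_size, n_params, lr = 4, 10, 0.1
full_params = rng.normal(size=n_params)
# Each rank computed a full-size gradient on its own micro-batch.
per_rank_grads = rng.normal(size=(world_size, n_params))

param_shards = np.array_split(full_params, world_size)
new_shards = zero3_step(param_shards, per_rank_grads, lr)
reconstructed = all_gather(new_shards)

# Reference: ordinary data parallelism with replicated parameters.
reference = full_params - lr * per_rank_grads.mean(axis=0)
print(np.allclose(reconstructed, reference))  # True
```

`np.array_split` allows uneven shards (here sizes 3, 3, 2, 2) as long as parameters and gradients are split identically; production frameworks typically pad to equal shard sizes instead.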
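With world_size=64 GPUs, ZeRO Stage 3 reduces per-GPU memory for model states (parameters, gradients, and optimizer states) by roughly 64×, not counting activations and temporarily gathered parameter buffers. This is what makes training 175B+ parameter models feasible.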
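A back-of-the-envelope check of the 64× claim. It assumes mixed-precision Adam with 16 bytes of model state per parameter (2 for fp16 weights, 2 for fp16 gradients, 12 for the fp32 master weights, momentum, and variance) and ignores activations and communication buffers; `model_state_gb` is a hypothetical helper.

```python
# Assumption: mixed-precision Adam, 16 bytes of model state per parameter.
BYTES_PER_PARAM = 2 + 2 + 12  # fp16 weights + fp16 grads + fp32 (weights, momentum, variance)


def model_state_gb(num_params, world_size=1):
    """Per-GPU model-state memory in GB (1 GB = 1e9 bytes) when fully sharded."""
    return num_params * BYTES_PER_PARAM / world_size / 1e9


n = 175e9
print(model_state_gb(n))                 # 2800.0  (fully replicated)
print(model_state_gb(n, world_size=64))  # 43.75   (sharded over 64 GPUs)
```

Under these assumptions a 175B-parameter model needs about 2.8 TB of model state per GPU without sharding, which no single GPU can hold, versus about 44 GB per GPU when sharded 64 ways.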
import numpy as np

def zero_stage3(params_shards, gradients_shards, lr):
    # TODO: apply one SGD step to each rank's parameter shard
    pass