TorchedUp

79. FSDP: Sharded Parameters

Hard

FSDP (Fully Sharded Data Parallel, equivalent to ZeRO Stage 3) shards model parameters across workers, so each worker permanently stores only its own shard. Before each forward pass, workers all-gather the shards to reconstruct the full parameters, run the forward pass, then immediately discard the shards they do not own to save memory.
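The shard, gather, discard cycle described above can be sketched in plain NumPy. This is only a single-process simulation: `shard_rowwise` and `simulated_all_gather` are hypothetical helper names introduced here, and the "all-gather" is just a reshape that stacks every worker's rows, not a real collective.

```python
import numpy as np

def shard_rowwise(W, num_workers):
    """Split W of shape (d_out, d_in) into equal row blocks (hypothetical helper)."""
    d_out, d_in = W.shape
    if d_out % num_workers != 0:
        raise ValueError("this sketch assumes d_out is divisible by num_workers")
    # Result shape: (num_workers, shard_rows, d_in); shard i holds consecutive rows.
    return W.reshape(num_workers, d_out // num_workers, d_in)

def simulated_all_gather(shards):
    """Stand-in for the all-gather collective: stack every worker's rows in rank order."""
    num_workers, shard_rows, d_in = shards.shape
    return shards.reshape(num_workers * shard_rows, d_in)

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 3))           # full weight, d_out=8, d_in=3
shards = shard_rowwise(W, num_workers=4)  # each worker keeps one (2, 3) block

W_full = simulated_all_gather(shards)     # temporarily materialise the full weight
persistent_elems = shards[0].size         # what one worker stores between steps
peak_elems = W_full.size                  # what it holds during the forward pass
del W_full                                # "discard" the gathered copy afterwards
```

In this toy setup each worker stores a quarter of the weight between steps and only briefly holds the full matrix, which is the memory saving the paragraph above refers to.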

Simulate one FSDP forward step: given shards from all workers, reconstruct the full weight matrix, run a linear forward pass, and return the result.

Signature: def fsdp_forward(x, weight_shards, bias, worker_rank)

  • x: (batch, d_in), the input
  • weight_shards: (num_workers, shard_rows, d_in), W sharded row-wise across workers, so d_out = num_workers × shard_rows
  • bias: (d_out,)
  • worker_rank: int, the index of the shard this worker owns (for simulation only; the output does not depend on it)
  • Returns: (batch, d_out), the linear output (identical to using the full weight)
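To make the shapes concrete, here is a small set of example inputs. The specific sizes are arbitrary choices for illustration, and the relation d_out = num_workers × shard_rows is inferred from the row-wise sharding rather than stated as a separate rule.

```python
import numpy as np

# Illustrative sizes only; any consistent values would do.
batch, d_in = 5, 3
num_workers, shard_rows = 4, 2
d_out = num_workers * shard_rows          # rows of the full weight

x = np.ones((batch, d_in))                                # (5, 3)
weight_shards = np.zeros((num_workers, shard_rows, d_in)) # (4, 2, 3)
bias = np.zeros(d_out)                                    # (8,)
worker_rank = 1

# The shard this worker owns is one row block of the full weight.
own_shard = weight_shards[worker_rank]                    # (2, 3)
expected_output_shape = (batch, d_out)                    # (5, 8)
```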

The key insight: once all-gather has reconstructed the full weight, the forward pass is just x @ W_full.T + bias.
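One possible way to write this, as a sketch rather than the site's reference solution: the all-gather is simulated by reshaping the stacked shards into a single (d_out, d_in) matrix, which assumes the shards are stored in rank order. The range check on `worker_rank` is an added assumption; the rank is otherwise unused because every worker sees the same gathered weight.

```python
import numpy as np

def fsdp_forward(x, weight_shards, bias, worker_rank):
    """Simulate one FSDP forward step for a linear layer (illustrative sketch)."""
    x = np.asarray(x)
    weight_shards = np.asarray(weight_shards)
    bias = np.asarray(bias)

    num_workers, shard_rows, d_in = weight_shards.shape
    if not 0 <= worker_rank < num_workers:
        raise ValueError("worker_rank must index one of the shards")

    # Simulated all-gather: stack shards in rank order into (d_out, d_in).
    W_full = weight_shards.reshape(num_workers * shard_rows, d_in)

    # Ordinary linear forward pass with the reconstructed weight.
    out = x @ W_full.T + bias

    # In real FSDP the non-owned shards would be freed here; in this
    # simulation we simply drop the reference to the gathered matrix.
    del W_full
    return out
```

A quick way to sanity-check it is to shard a known weight matrix, call the function with any valid rank, and compare against `x @ W.T + bias` computed directly.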

Math

$$W_{\text{full}} = \mathrm{AllGather}(W_0, W_1, \ldots, W_{N-1}), \qquad \text{out} = x\,W_{\text{full}}^{\top} + b$$

NumPy

import numpy as np

def fsdp_forward(x, weight_shards, bias, worker_rank):
    # TODO: all-gather the shards, then compute x @ W_full.T + bias
    pass
