TorchedUp
FSDP: Sharded Parameters

Difficulty: Hard

FSDP (Fully Sharded Data Parallel, equivalent to ZeRO Stage 3) shards model parameters across workers, so each worker persistently stores only its own shard. Before a layer's forward pass, the workers all-gather the shards to reconstruct the full parameters, run the forward pass, and then immediately discard the shards they do not own to save memory.
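
To make the memory saving concrete, here is a small sketch that compares the parameter bytes a worker keeps between steps with the bytes it briefly holds after the all-gather. The shapes and the float32 dtype are illustrative assumptions, and the count covers parameter storage for a single layer only (no gradients, optimizer state, or activations).

```python
import numpy as np

# Illustrative sizes (assumptions): 2 workers, 3 rows per shard, d_in = 4.
num_workers, shard_rows, d_in = 2, 3, 4
weight_shards = np.zeros((num_workers, shard_rows, d_in), dtype=np.float32)

worker_rank = 0

# Bytes this worker stores persistently: only its own shard.
owned_bytes = weight_shards[worker_rank].nbytes

# Bytes held transiently after the all-gather: every shard at once.
gathered_bytes = weight_shards.nbytes

print(owned_bytes, gathered_bytes)  # 48 96
```

With N workers the persistent parameter footprint is 1/N of the full weight, while the peak during the layer's forward pass is the full weight.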

Simulate one FSDP forward step: given shards from all workers, reconstruct the full weight matrix, run a linear forward pass, and return the result.

Signature: def fsdp_forward(x, weight_shards, bias, worker_rank)

  • x: (batch, d_in), the input
  • weight_shards: (num_workers, shard_rows, d_in), W sharded row-wise across workers, so d_out = num_workers * shard_rows
  • bias: (d_out,)
  • worker_rank: int, the shard this worker owns (used only for the simulation; it does not change the output)
  • Returns: (batch, d_out), the linear output (identical to using the full weight)
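
The shapes above can be produced from an ordinary (d_out, d_in) weight matrix. The sketch below uses a hypothetical helper, make_row_shards, which is not part of the problem statement; it assumes d_out divides evenly by num_workers and that worker r owns rows r * shard_rows through (r + 1) * shard_rows - 1. The random values are arbitrary and do not reproduce the site's test data.

```python
import numpy as np

def make_row_shards(weight, num_workers):
    """Hypothetical helper: split (d_out, d_in) into (num_workers, shard_rows, d_in)."""
    d_out, d_in = weight.shape
    if d_out % num_workers != 0:
        raise ValueError("d_out must be divisible by num_workers")
    # Row-major reshape gives worker r a contiguous block of rows.
    return weight.reshape(num_workers, d_out // num_workers, d_in)

rng = np.random.default_rng(0)
W_full = rng.standard_normal((6, 4))             # d_out = 6, d_in = 4
weight_shards = make_row_shards(W_full, num_workers=2)

print(weight_shards.shape)                        # (2, 3, 4)
```

Worker 1's shard, weight_shards[1], is exactly rows 3 to 5 of W_full.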

The key insight: once the all-gather reconstructs the full weight W_full of shape (d_out, d_in), the forward pass is just x @ W_full.T + bias.
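
One possible solution, as a minimal sketch: the all-gather is simulated by concatenating the shards in rank order, and the rank check is an added assumption about how invalid input should be handled rather than something the problem specifies.

```python
import numpy as np

def fsdp_forward(x, weight_shards, bias, worker_rank):
    weight_shards = np.asarray(weight_shards)
    num_workers = weight_shards.shape[0]
    if not 0 <= worker_rank < num_workers:
        raise ValueError("worker_rank must be in [0, num_workers)")

    # Simulated all-gather: collect every worker's shard in rank order
    # and stack them along the row (d_out) axis.
    w_full = np.concatenate(
        [weight_shards[r] for r in range(num_workers)], axis=0
    )  # shape (d_out, d_in)

    # Linear forward pass with the reconstructed weight.
    out = x @ w_full.T + bias

    # A real worker would now free every shard except
    # weight_shards[worker_rank]; here w_full simply goes out of scope.
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))
shards = rng.standard_normal((2, 3, 4))
bias = rng.standard_normal(6)

y0 = fsdp_forward(x, shards, bias, worker_rank=0)
y1 = fsdp_forward(x, shards, bias, worker_rank=1)
print(y0.shape)  # (2, 6)
```

Because every worker ends up with the same gathered weight, y0 and y1 are identical.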

Math
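
Written out from the signature above, with N = num_workers, s = shard_rows, and W_r the shard owned by worker r (these symbol names are chosen here for illustration):

```latex
W_{\text{full}} =
\begin{bmatrix} W_0 \\ W_1 \\ \vdots \\ W_{N-1} \end{bmatrix}
\in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},
\qquad W_r \in \mathbb{R}^{s \times d_{\text{in}}},
\qquad d_{\text{out}} = N s

Y = X W_{\text{full}}^{\top} + b,
\qquad
Y_{ij} = \sum_{k=1}^{d_{\text{in}}} X_{ik} \, (W_{\text{full}})_{jk} + b_j
```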

Language: Python (numpy)

Test Results

  • batch=2, d_in=4, d_out=6, 2 workers (seed 42)
  • worker_rank=1 gives the same output as worker_rank=0 (after the all-gather, every worker holds the same full weight)
  • batch=3, d_in=6, d_out=8, 4 workers (seed 7)