FSDP (Fully Sharded Data Parallel, equivalent to ZeRO Stage 3) shards model parameters across workers, so each worker stores only its own slice. Before each forward pass, the workers all-gather the shards to reconstruct the full parameters, run the forward pass, and then immediately discard the shards they do not own to save memory.
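To make the sharding step concrete, here is a minimal NumPy sketch of row-wise sharding and a simulated all-gather. The helper names `shard_rows` and `all_gather` are hypothetical and chosen for illustration only. The sketch assumes the number of rows divides evenly across workers, whereas real FSDP implementations typically flatten and pad parameters.

```python
import numpy as np

def shard_rows(weight, num_workers):
    """Split a (d_out, d_in) weight row-wise into equal shards.

    Returns an array of shape (num_workers, d_out // num_workers, d_in).
    Assumption for this sketch: d_out is divisible by num_workers.
    """
    d_out = weight.shape[0]
    if d_out % num_workers != 0:
        raise ValueError("d_out must be divisible by num_workers in this sketch")
    return np.stack(np.split(weight, num_workers, axis=0))

def all_gather(shards):
    """Simulated all-gather: concatenate the shards in rank order."""
    return shards.reshape(-1, shards.shape[-1])

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 3))          # full weight, d_out=8, d_in=3
shards = shard_rows(W, num_workers=4)    # each worker holds a (2, 3) slice
W_full = all_gather(shards)              # reconstructed (8, 3) weight
```

In this toy setup each worker stores a quarter of the rows between forward passes, and the gathered matrix is identical to the original `W`.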
Simulate one FSDP forward step: given shards from all workers, reconstruct the full weight matrix, run a linear forward pass, and return the result.
Signature: def fsdp_forward(x, weight_shards, bias, worker_rank)
Arguments:
  x: (batch, d_in), the input
  weight_shards: (num_workers, shard_rows, d_in), W sharded row-wise across workers
  bias: (d_out,)
  worker_rank: int, which shard this worker owns (for simulation)
Returns:
  (batch, d_out), the linear output (same as if using the full weight)

The key insight: all-gather reconstructs the full weight, and the forward pass is then just x @ W_full.T + bias.
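One possible solution, sketched in NumPy under the assumption that the shards are stacked in rank order and that d_out equals num_workers * shard_rows. The rank check and the array conversions are additions for safety and are not required by the problem statement.

```python
import numpy as np

def fsdp_forward(x, weight_shards, bias, worker_rank):
    """Simulate one FSDP forward step for a linear layer."""
    x = np.asarray(x, dtype=float)
    weight_shards = np.asarray(weight_shards, dtype=float)
    bias = np.asarray(bias, dtype=float)

    num_workers, shard_rows, d_in = weight_shards.shape
    if not 0 <= worker_rank < num_workers:
        raise ValueError("worker_rank must be in [0, num_workers)")

    # Simulated all-gather: stack the row shards in rank order to rebuild
    # W_full with shape (num_workers * shard_rows, d_in) = (d_out, d_in).
    w_full = weight_shards.reshape(num_workers * shard_rows, d_in)

    # Linear forward pass with the full weight.
    out = x @ w_full.T + bias

    # In a real system the worker would now free every shard except
    # weight_shards[worker_rank]; in this simulation w_full simply goes
    # out of scope when the function returns.
    return out

# Two workers, one row each; the gathered weight is the 2x2 identity.
x = [[1.0, 2.0]]
weight_shards = [[[1.0, 0.0]], [[0.0, 1.0]]]
bias = [0.5, -0.5]
result = fsdp_forward(x, weight_shards, bias, worker_rank=0)
```

Because every worker gathers the same full weight, the output does not depend on `worker_rank`; the argument only identifies which shard the worker would keep afterwards.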