Implement Megatron-LM's column-parallel + row-parallel linear layer pattern for splitting a two-layer MLP across GPUs.
Signature: def megatron_mlp(x, W1_shards, W2_shards)
x: (batch, d_in), the input, replicated on all ranks
W1_shards: list of world_size matrices, each (d_hidden//world_size, d_in); the column-parallel first layer
W2_shards: list of world_size matrices, each (d_out, d_hidden//world_size); the row-parallel second layer
Returns: (batch, d_out), the output after the all-reduce across ranks
The first linear layer is column-parallel: each rank holds a slice of the hidden dimension and computes its own slice of the hidden activation, h_i = x @ W1_shards[i].T with shape (batch, d_hidden//world_size), from the replicated input. This step needs no communication. The second linear layer is row-parallel: each rank multiplies its hidden slice by its slice of the second weight matrix, h_i @ W2_shards[i].T, producing a partial output of full (batch, d_out) shape. The final output is the sum of these partial outputs across all ranks; this single all-reduce is the only collective in the forward pass of the MLP.
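A minimal NumPy sketch of the pattern, assuming all ranks are simulated sequentially in one process: the sum over per-rank partial outputs stands in for the all-reduce (a real multi-GPU version would use a collective such as `torch.distributed.all_reduce`). No activation is applied between the layers because the problem statement does not mention one. The random sizes in the usage part are illustrative only, and this is a sketch rather than the site's reference solution.

```python
import numpy as np


def megatron_mlp(x, W1_shards, W2_shards):
    """Simulate a column-parallel + row-parallel MLP over len(W1_shards) ranks.

    x:         (batch, d_in), replicated on every rank
    W1_shards: world_size arrays of shape (d_hidden // world_size, d_in)
    W2_shards: world_size arrays of shape (d_out, d_hidden // world_size)
    Returns:   (batch, d_out)
    """
    assert len(W1_shards) == len(W2_shards), "one W1 and one W2 shard per rank"
    partials = []
    for W1_i, W2_i in zip(W1_shards, W2_shards):
        # Column-parallel: this rank's slice of the hidden activation, no communication.
        h_i = x @ W1_i.T  # (batch, d_hidden // world_size)
        # Row-parallel: a partial output that already has the full (batch, d_out) shape.
        partials.append(h_i @ W2_i.T)
    # Simulated all-reduce (sum over ranks): the only collective in the forward pass.
    return np.sum(partials, axis=0)


rng = np.random.default_rng(0)
batch, d_in, d_hidden, d_out, world_size = 4, 6, 8, 5, 2
x = rng.standard_normal((batch, d_in))
W1 = rng.standard_normal((d_hidden, d_in))
W2 = rng.standard_normal((d_out, d_hidden))

# Split the hidden dimension: rows of W1, columns of W2.
W1_shards = np.split(W1, world_size, axis=0)
W2_shards = np.split(W2, world_size, axis=1)

out = megatron_mlp(x, W1_shards, W2_shards)
print(out.shape)                          # (4, 5)
print(np.allclose(out, x @ W1.T @ W2.T))  # True
```

The loop body touches only one rank's shards, which is what lets each iteration run on a separate GPU; the final sum is the one place where ranks must exchange data.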
Equivalently, the result equals that of the unsharded two-layer MLP whose weights are the shards concatenated along the hidden dimension: W1 = concatenate(W1_shards, axis=0) with shape (d_hidden, d_in), and W2 = concatenate(W2_shards, axis=1) with shape (d_out, d_hidden). The sharding only distributes the work.
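The equivalence follows from block-matrix multiplication. In the derivation below, $p$ is world_size and $W_1^{(i)}$, $W_2^{(i)}$ denote the $i$-th entries of W1_shards and W2_shards (notation introduced here for the derivation, with $i$ running from 1 to $p$):

```latex
W_1 = \begin{bmatrix} W_1^{(1)} \\ \vdots \\ W_1^{(p)} \end{bmatrix} \in \mathbb{R}^{d_{\text{hidden}} \times d_{\text{in}}},
\qquad
W_2 = \begin{bmatrix} W_2^{(1)} & \cdots & W_2^{(p)} \end{bmatrix} \in \mathbb{R}^{d_{\text{out}} \times d_{\text{hidden}}}

y = x W_1^{\top} W_2^{\top}
  = \begin{bmatrix} x W_1^{(1)\top} & \cdots & x W_1^{(p)\top} \end{bmatrix}
    \begin{bmatrix} W_2^{(1)\top} \\ \vdots \\ W_2^{(p)\top} \end{bmatrix}
  = \sum_{i=1}^{p} \underbrace{x W_1^{(i)\top}}_{h_i} W_2^{(i)\top}
```

Each term $h_i W_2^{(i)\top}$ is one rank's (batch, d_out) partial output, and the sum over $i$ is exactly what the all-reduce computes. The same split survives an elementwise activation between the layers, since applying it to the concatenated hidden activation is the same as applying it to each slice $h_i$ separately; this is why Megatron-LM puts the column-parallel layer first.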
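For a Transformer FFN with d_hidden = 16384 split across 8 GPUs, each GPU holds only 16384/8 = 2048 hidden units, together with the matching slices of both weight matrices. Communication is a single all-reduce of the (batch, d_out) partial outputs per MLP block in the forward pass; the first layer needs none.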
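The per-GPU arithmetic can be checked with a few lines of Python. Only d_hidden = 16384 and world_size = 8 come from the example; d_in = d_out = 4096 and batch = 1024 are hypothetical values chosen for illustration (d_in = 4096 matches the common d_hidden = 4 * d_model FFN ratio).

```python
# Per-GPU sizes for the sharded FFN; d_in, d_out and batch are hypothetical.
d_hidden, world_size = 16384, 8  # from the example
d_in = d_out = 4096              # hypothetical: d_hidden = 4 * d_model
batch = 1024                     # hypothetical number of tokens

hidden_per_gpu = d_hidden // world_size
w1_shard_shape = (hidden_per_gpu, d_in)
w2_shard_shape = (d_out, hidden_per_gpu)

params_per_gpu = hidden_per_gpu * d_in + d_out * hidden_per_gpu
params_unsharded = d_hidden * d_in + d_out * d_hidden

# The all-reduce sums one (batch, d_out) partial output per rank,
# so its size does not depend on d_hidden.
allreduce_elems = batch * d_out

print(hidden_per_gpu)                    # 2048
print(w1_shard_shape, w2_shard_shape)    # (2048, 4096) (4096, 2048)
print(params_per_gpu, params_unsharded)  # 16777216 134217728
print(allreduce_elems)                   # 4194304
```

Under these assumptions each GPU stores 1/8 of the FFN weights, while the tensor that has to be communicated stays (batch, d_out) no matter how large d_hidden grows.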
import numpy as np
def megatron_mlp(x, W1_shards, W2_shards):
    # x: (batch, d_in); return the (batch, d_out) output after the all-reduce.
    pass