Implement Megatron-LM's column-parallel + row-parallel linear layer pattern for splitting a two-layer MLP across GPUs.
Signature: def megatron_mlp(x, W1_shards, W2_shards)
x: (batch, d_in) — input (replicated on all ranks)
W1_shards: list of world_size matrices, each (d_hidden//world_size, d_in) — column-parallel first layer
W2_shards: list of world_size matrices, each (d_out, d_hidden//world_size) — row-parallel second layer
Returns: (batch, d_out) — output (after all-reduce across ranks)

# Column-parallel (first layer): each rank computes a column shard
# Rank i: h_i = x @ W1_shards[i].T → (batch, d_hidden//world_size)
# Row-parallel (second layer): each rank computes partial output
# Rank i: out_i = h_i @ W2_shards[i].T → (batch, d_out)
# All-reduce: sum partial outputs
# output = sum(out_i for i in range(world_size))
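The per-rank steps above translate almost line-for-line into code. A minimal single-process sketch (NumPy, with the all-reduce simulated by summing the per-rank partial outputs in a loop rather than using a real collective):

import numpy as np

def megatron_mlp(x, W1_shards, W2_shards):
    # Single-process simulation of the sharded MLP; each loop iteration
    # plays the role of one rank, and the final sum stands in for all-reduce.
    world_size = len(W1_shards)
    partial_outputs = []
    for i in range(world_size):
        # Column-parallel first layer: rank i owns a slice of the hidden dim.
        h_i = x @ W1_shards[i].T            # (batch, d_hidden // world_size)
        # Row-parallel second layer: rank i emits a partial (batch, d_out) output.
        partial_outputs.append(h_i @ W2_shards[i].T)
    # "All-reduce": summing the partials yields the full output on every rank.
    return np.sum(partial_outputs, axis=0)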
This is exactly equivalent to:
h = x @ W1.T, where W1 = vstack(W1_shards)      # unsharded first-layer weight
output = h @ W2.T, where W2 = hstack(W2_shards) # unsharded second-layer weight
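As a sanity check on that equivalence, a small NumPy example (dimensions chosen arbitrarily; reuses the megatron_mlp sketch above) compares the sharded computation against the full-matrix form:

import numpy as np

rng = np.random.default_rng(0)
batch, d_in, d_hidden, d_out, world_size = 4, 16, 32, 16, 4
x = rng.standard_normal((batch, d_in))
W1 = rng.standard_normal((d_hidden, d_in))
W2 = rng.standard_normal((d_out, d_hidden))

# Shard W1 along the hidden (row) dim and W2 along the hidden (column) dim,
# so that vstack(W1_shards) == W1 and hstack(W2_shards) == W2.
W1_shards = np.split(W1, world_size, axis=0)   # each (d_hidden//world_size, d_in)
W2_shards = np.split(W2, world_size, axis=1)   # each (d_out, d_hidden//world_size)

assert np.allclose(megatron_mlp(x, W1_shards, W2_shards), (x @ W1.T) @ W2.T)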
For a Transformer FFN with d_hidden=16384 split across 8 GPUs, each GPU holds only 16384/8 = 2048 hidden units. Forward-pass communication is a single all-reduce for the whole two-layer MLP (summing the partial outputs after the row-parallel layer).
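To make the sizing concrete, a back-of-the-envelope sketch (assuming a typical d_model = 4096 so that d_hidden = 4 * d_model = 16384; the model dimension is not given in the problem, so this value is illustrative):

world_size = 8
d_model, d_hidden = 4096, 16384        # d_model is an assumed, typical value
shard_hidden = d_hidden // world_size  # 2048 hidden units per GPU
per_rank_params = shard_hidden * d_model + d_model * shard_hidden  # W1 shard + W2 shard
total_params = 2 * d_hidden * d_model                              # unsharded MLP
print(per_rank_params, total_params // world_size)  # 16777216 16777216 — an even 1/8 split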