Megatron Tensor Parallelism (Hard)

Implement Megatron-LM's column-parallel + row-parallel linear layer pattern for splitting a two-layer MLP across GPUs.

Signature: def megatron_mlp(x, W1_shards, W2_shards)

  • x: (batch, d_in) — input (replicated on all ranks)
  • W1_shards: list of world_size matrices, each (d_hidden//world_size, d_in) — column-parallel first layer
  • W2_shards: list of world_size matrices, each (d_out, d_hidden//world_size) — row-parallel second layer
  • Returns: (batch, d_out) — output (after all-reduce across ranks; see the shape sketch below)
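
The shard shapes are easiest to see from how the lists could be constructed. A minimal numpy sketch, assuming hypothetical full weight matrices W1 of shape (d_hidden, d_in) and W2 of shape (d_out, d_hidden) that get split across ranks (these full matrices and the toy sizes are illustrative assumptions, not part of the signature):

import numpy as np

# Toy setup (assumed sizes): build the shard lists from hypothetical full matrices.
rng = np.random.default_rng(0)
batch, d_in, d_hidden, d_out, world_size = 2, 4, 8, 4, 2

x = rng.standard_normal((batch, d_in))
W1 = rng.standard_normal((d_hidden, d_in))    # full first-layer weight
W2 = rng.standard_normal((d_out, d_hidden))   # full second-layer weight

# Column-parallel: split W1 along its output-feature (row) axis.
W1_shards = np.split(W1, world_size, axis=0)  # each (d_hidden//world_size, d_in)
# Row-parallel: split W2 along its input-feature (column) axis.
W2_shards = np.split(W2, world_size, axis=1)  # each (d_out, d_hidden//world_size)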

Column-Parallel + Row-Parallel Pattern

# Column-parallel (first layer): each rank computes a column shard
#   Rank i: h_i = x @ W1_shards[i].T   → (batch, d_hidden//world_size)

# Row-parallel (second layer): each rank computes partial output
#   Rank i: out_i = h_i @ W2_shards[i].T  → (batch, d_out)

# All-reduce: sum partial outputs
#   output = sum(out_i for i in range(world_size))
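
One way to turn the three steps above into a single-process numpy sketch is to simulate each rank with a loop iteration and the all-reduce with a running sum (real Megatron-LM uses actual collectives across GPUs):

def megatron_mlp(x, W1_shards, W2_shards):
    """Simulated tensor-parallel two-layer MLP; one loop iteration per rank."""
    world_size = len(W1_shards)
    output = None
    for i in range(world_size):
        h_i = x @ W1_shards[i].T      # column-parallel: (batch, d_hidden//world_size)
        out_i = h_i @ W2_shards[i].T  # row-parallel partial output: (batch, d_out)
        # "All-reduce": accumulate the partial outputs across ranks.
        output = out_i if output is None else output + out_i
    return output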

This is exactly equivalent to:

h = x @ W1.T       where W1 = vstack(W1_shards)   # reassembled first-layer weight
output = h @ W2.T  where W2 = hstack(W2_shards)   # reassembled second-layer weight
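
A quick numerical check of this equivalence, reusing the toy shards and the megatron_mlp sketch above:

W1_full = np.vstack(W1_shards)           # (d_hidden, d_in)
W2_full = np.hstack(W2_shards)           # (d_out, d_hidden)
reference = (x @ W1_full.T) @ W2_full.T  # unsharded two-layer MLP
assert np.allclose(megatron_mlp(x, W1_shards, W2_shards), reference)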

Why This Matters

For a Transformer FFN with d_hidden=16384 split across 8 GPUs, each GPU holds only 16384/8 = 2048 of the hidden units (and the corresponding slices of W1 and W2). Communication is a single all-reduce for the whole two-layer block (just summing the partial outputs); nothing needs to be exchanged between the first and second layers.
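
Back-of-envelope numbers for that configuration (d_model = 4096 is an assumed value here, not given above):

# Per-rank parameter count, assuming a hypothetical d_model = 4096 for d_in = d_out.
d_model, d_hidden, world_size = 4096, 16384, 8
shard = d_hidden // world_size                       # 2048 hidden units per GPU
full_params = 2 * d_model * d_hidden                 # unsharded W1 + W2
per_rank_params = shard * d_model + d_model * shard  # W1 shard + W2 shard
print(shard, per_rank_params / full_params)          # 2048  0.125 -> 1/8 of the weights per GPU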

Language: Python (numpy)

Test Results

  • 2 ranks: 1×4 input, W1 splits 4→4 hidden (2 per rank), W2 maps 2→4 output per rank (reconstructed in the sketch below)
  • single rank (no sharding): standard matmul
  • result equals full unsplit matmul
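
A plausible reconstruction of the 2-rank case (the actual fixture values are hidden; random weights stand in for them), reusing megatron_mlp from the sketch above:

rng = np.random.default_rng(1)
x = rng.standard_normal((1, 4))                                # 1×4 input
W1_shards = [rng.standard_normal((2, 4)) for _ in range(2)]    # 4 hidden units, 2 per rank
W2_shards = [rng.standard_normal((4, 2)) for _ in range(2)]    # each maps 2 -> 4 outputs

out = megatron_mlp(x, W1_shards, W2_shards)
ref = (x @ np.vstack(W1_shards).T) @ np.hstack(W2_shards).T    # full unsplit matmul
assert out.shape == (1, 4) and np.allclose(out, ref)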