TorchedUp

99. Megatron Tensor Parallelism

Hard

Implement Megatron-LM's column-parallel + row-parallel linear layer pattern for splitting a two-layer MLP across GPUs.

Signature: def megatron_mlp(x, W1_shards, W2_shards)

  • x: (batch, d_in) — input (replicated on all ranks)
  • W1_shards: list of world_size matrices, each (d_hidden//world_size, d_in) — column-parallel first layer
  • W2_shards: list of world_size matrices, each (d_out, d_hidden//world_size) — row-parallel second layer
  • Returns: (batch, d_out) — output (after all-reduce across ranks)
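To make the shard layout concrete, here is a minimal sketch. It assumes the full weights are W1 of shape (d_hidden, d_in) and W2 of shape (d_out, d_hidden), in the same (out, in) layout as the shards. shard_weights is a hypothetical helper and is not part of the problem. Shard i of W1 and shard i of W2 must cover the same hidden units, which splitting both in order guarantees.

```python
import numpy as np


def shard_weights(W1, W2, world_size):
    """Hypothetical helper: split full MLP weights into Megatron-style shards.

    W1: (d_hidden, d_in)  -> world_size row blocks    (column-parallel in x @ W1.T)
    W2: (d_out, d_hidden) -> world_size column blocks (row-parallel in h @ W2.T)
    """
    d_hidden = W1.shape[0]
    if W2.shape[1] != d_hidden:
        raise ValueError("W2 must have d_hidden columns")
    if d_hidden % world_size != 0:
        raise ValueError("d_hidden must be divisible by world_size")
    # Splitting both along the hidden dim, in the same order, keeps rank i's
    # W1 rows aligned with rank i's W2 columns.
    W1_shards = np.split(W1, world_size, axis=0)  # each (d_hidden//world_size, d_in)
    W2_shards = np.split(W2, world_size, axis=1)  # each (d_out, d_hidden//world_size)
    return W1_shards, W2_shards


rng = np.random.default_rng(0)
W1 = rng.standard_normal((8, 4))  # d_hidden=8, d_in=4
W2 = rng.standard_normal((3, 8))  # d_out=3
W1_shards, W2_shards = shard_weights(W1, W2, world_size=2)
print([s.shape for s in W1_shards])  # [(4, 4), (4, 4)]
print([s.shape for s in W2_shards])  # [(3, 4), (3, 4)]
```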

Column-Parallel + Row-Parallel Pattern

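The first linear layer is column-parallel. Each rank holds a slice of the hidden dimension and computes its own slice of the hidden activation, h_i = x @ W1_i.T of shape (batch, d_hidden//world_size), from the replicated input. The weights use the (out_features, in_features) layout, so splitting the columns of W1.T is the same as splitting the rows of W1. No communication is needed between the two layers, because each rank's h_i is exactly the slice of the hidden dimension that its W2 shard consumes. The second linear layer is row-parallel: each rank computes h_i @ W2_i.T, a partial output of full (batch, d_out) shape. The final output is the sum of those partial outputs across all ranks, so this single all-reduce is the only collective in the forward pass of the whole MLP.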
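The pattern can be sketched in a single process. This assumes the ranks are simulated by a Python loop and the all-reduce by a plain sum, so it is not the Megatron-LM implementation. The optional act argument goes beyond the stated signature. Megatron-LM applies GeLU between the two layers, and because the nonlinearity is elementwise each rank can apply it to its own slice. The problem statement does not say whether a nonlinearity is expected, so act defaults to identity.

```python
import numpy as np


def megatron_mlp(x, W1_shards, W2_shards, act=None):
    """Single-process simulation of a column-parallel + row-parallel MLP.

    x:         (batch, d_in), replicated on every rank
    W1_shards: world_size arrays, each (d_hidden // world_size, d_in)
    W2_shards: world_size arrays, each (d_out, d_hidden // world_size)
    act:       optional elementwise nonlinearity; an assumption, not part
               of the problem's signature (None means identity)
    """
    if len(W1_shards) != len(W2_shards):
        raise ValueError("need one W1 shard and one W2 shard per rank")

    partials = []
    for W1_i, W2_i in zip(W1_shards, W2_shards):  # one iteration per simulated rank
        # Column-parallel layer: this rank's slice of the hidden activation.
        h_i = x @ W1_i.T  # (batch, d_hidden // world_size)
        if act is not None:
            # Elementwise, so it runs on the local slice with no communication.
            h_i = act(h_i)
        # Row-parallel layer: h_i already matches W2_i's input slice, so no
        # all-gather is needed; the result is a full-shape partial output.
        partials.append(h_i @ W2_i.T)  # (batch, d_out)

    # Stand-in for the all-reduce: sum the partial outputs across ranks.
    return np.sum(partials, axis=0)


# Usage: world_size=2, d_in=4, d_hidden=6, d_out=5, batch=2
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))
W1_shards = [rng.standard_normal((3, 4)) for _ in range(2)]
W2_shards = [rng.standard_normal((5, 3)) for _ in range(2)]
y = megatron_mlp(x, W1_shards, W2_shards)
print(y.shape)  # (2, 5)
```

On real GPUs each rank would run one iteration of the loop. The final sum would then be a collective such as torch.distributed.all_reduce with a SUM reduction.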

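Equivalently, the result matches the unsharded two-layer MLP. Concatenate W1_shards along axis 0 (the hidden dim) into W1 of shape (d_hidden, d_in), concatenate W2_shards along axis 1 (the same hidden dim) into W2 of shape (d_out, d_hidden), and compute (x @ W1.T) @ W2.T. The sharding only distributes the work. It does not change the answer, apart from floating-point differences in summation order.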
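The equivalence follows from block matrix multiplication. Here p stands for world_size, and W_1^{(i)}, W_2^{(i)}, h^{(i)} are rank i's shards and hidden slice (notation introduced for this derivation). W1 stacks its shards as row blocks and W2 stacks its shards as column blocks, which gives:

```latex
\begin{aligned}
h &= x W_1^\top
   = \bigl[\, x W_1^{(0)\top} \;\big|\; \cdots \;\big|\; x W_1^{(p-1)\top} \,\bigr]
   = \bigl[\, h^{(0)} \;\big|\; \cdots \;\big|\; h^{(p-1)} \,\bigr], \\[4pt]
y &= h W_2^\top
   = \bigl[\, h^{(0)} \;\big|\; \cdots \;\big|\; h^{(p-1)} \,\bigr]
     \begin{bmatrix} W_2^{(0)\top} \\ \vdots \\ W_2^{(p-1)\top} \end{bmatrix}
   = \sum_{i=0}^{p-1} h^{(i)} W_2^{(i)\top}.
\end{aligned}
```

Each term in the sum is one rank's partial output, and the sum itself is the all-reduce. An elementwise nonlinearity applied to h acts on each block independently, so it can be applied on every rank before the second layer without any communication.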

Why This Matters

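For a Transformer FFN with d_hidden=16384 split across 8 GPUs, each GPU holds only 16384/8 = 2048 hidden units, which is one eighth of W1 and one eighth of W2. Communication is a single all-reduce per MLP block in the forward pass, and it simply sums the (batch, d_out) partial outputs. The backward pass needs one matching all-reduce for the gradient with respect to x.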
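Here are some back-of-the-envelope numbers for this example. They assume d_model = 4096 (the common 4x FFN ratio given d_hidden = 16384), 2-byte weights and activations, no biases, and 4096 tokens per batch. None of these values come from the problem statement.

```python
# Assumptions (not from the problem): d_model = 4096 so that d_hidden = 4 * d_model,
# 2-byte (fp16/bf16) elements, no biases, 8 ranks, 4096 tokens per batch.
d_model, d_hidden, world_size = 4096, 16384, 8
bytes_per_elem = 2
tokens = 4096

hidden_per_rank = d_hidden // world_size      # hidden units held by each rank
params_full = 2 * d_model * d_hidden          # W1 (d_hidden x d_model) + W2 (d_model x d_hidden)
params_per_rank = params_full // world_size   # each rank holds one shard of each

# Forward all-reduce payload: one (tokens, d_model) partial output per rank.
allreduce_bytes = tokens * d_model * bytes_per_elem

print(hidden_per_rank)                 # 2048
print(params_full, params_per_rank)    # 134217728 16777216
print(allreduce_bytes / 2**20, "MiB")  # 32.0 MiB
```

The all-reduce payload scales with batch × d_out and does not depend on d_hidden. Because of this, splitting the hidden dimension keeps communication fixed as the FFN grows wider.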

import numpy as np

def megatron_mlp(x, W1_shards, W2_shards):
    # TODO: return the (batch, d_out) output after summing the per-rank partials
    pass
