Implement sequence parallelism from Megatron-LM v3: shard the sequence dimension across ranks for non-tensor-parallel regions (LayerNorm, Dropout), reducing activation memory.
Signature: def sequence_parallel_layernorm(x_shards, gamma, beta)
x_shards: list of world_size arrays, each (S//world_size, d) — sequence shards
gamma: (d,) — LayerNorm scale
beta: (d,) — LayerNorm bias
Returns: world_size arrays — normalized shards (same structure as input)

Each rank independently applies LayerNorm to its sequence shard. Since LayerNorm normalizes over the feature dimension (not the sequence dimension), no communication is needed.
import numpy as np

def sequence_parallel_layernorm(x_shards, gamma, beta):
    out_shards = []
    for x_shard in x_shards:                       # each shard: (S//world_size, d)
        mu = x_shard.mean(axis=-1, keepdims=True)  # per-token mean over features
        var = x_shard.var(axis=-1, keepdims=True)  # per-token variance over features
        out_shards.append(gamma * (x_shard - mu) / np.sqrt(var + 1e-5) + beta)
    return out_shards
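Because the per-shard math is identical to ordinary LayerNorm, the sharded result can be checked against an unsharded reference. The following is a minimal single-process sketch assuming NumPy and hypothetical sizes (S=16, d=8, world_size=4); np.array_split stands in for however the runtime actually distributes the sequence across ranks.

import numpy as np

S, d, world_size = 16, 8, 4                     # hypothetical sizes; S divides evenly
x = np.random.randn(S, d)
gamma, beta = np.ones(d), np.zeros(d)

# Split the sequence dimension into the per-rank shards the function expects.
x_shards = np.array_split(x, world_size, axis=0)
out = np.concatenate(sequence_parallel_layernorm(x_shards, gamma, beta), axis=0)

# Reference: LayerNorm over the unsharded (S, d) activation.
mu, var = x.mean(axis=-1, keepdims=True), x.var(axis=-1, keepdims=True)
ref = gamma * (x - mu) / np.sqrt(var + 1e-5) + beta
assert np.allclose(out, ref)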
In Megatron's full setup, the sequence-parallel regions (LayerNorm and Dropout) are interleaved with tensor-parallel regions: an all-gather collects the sequence shards before the tensor-parallel attention/MLP blocks, and a reduce-scatter returns to the sequence-sharded layout afterward, replacing the all-reduces used in plain tensor parallelism.
This problem focuses only on the LayerNorm portion of sequence parallelism; the surrounding communication is sketched below for context.
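The sketch below shows where the LayerNorm sits in that data flow, again in a single process. The all_gather_seq and scatter_seq helpers are illustrative stand-ins, not Megatron's API; in the real system the return to sharded layout is a reduce-scatter that also sums partial results from the tensor-parallel ranks.

import numpy as np

def all_gather_seq(x_shards):
    # Stand-in for the all-gather entering a tensor-parallel region:
    # every rank reassembles the full (S, d) sequence.
    return np.concatenate(x_shards, axis=0)

def scatter_seq(x_full, world_size):
    # Stand-in for returning to the sequence-sharded layout afterward
    # (the real op is a reduce-scatter across tensor-parallel ranks).
    return np.array_split(x_full, world_size, axis=0)

S, d, world_size = 16, 8, 4                      # hypothetical sizes
x_shards = scatter_seq(np.random.randn(S, d), world_size)
gamma, beta = np.ones(d), np.zeros(d)

normed_shards = sequence_parallel_layernorm(x_shards, gamma, beta)
full_seq = all_gather_seq(normed_shards)         # enter the tensor-parallel region
# ... tensor-parallel attention / MLP would consume full_seq here ...
next_shards = scatter_seq(full_seq, world_size)  # back to shards for Dropout/LayerNorm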