Implement Expert Parallelism for MoE models: each rank hosts a subset of experts, and tokens are dispatched to the correct rank via all-to-all communication.
Signature: def expert_parallel_dispatch(tokens, routing, expert_Ws, rank, world_size)
tokens: (N, d) — all tokens (available on all ranks before dispatch)
routing: (N,) — expert index assigned to each token (0 to n_experts-1)
expert_Ws: list of weight matrices for the experts hosted on this rank, each (d, d)
rank: this rank's index (0-based)
world_size: total number of ranks (each rank hosts n_experts // world_size experts)
Returns: (N, d) — output tokens for all tokens processed by this rank's experts (rows for tokens routed to other ranks are zero)
Each rank owns experts [rank * experts_per_rank, (rank+1) * experts_per_rank).
import numpy as np

def expert_parallel_dispatch(tokens, routing, expert_Ws, rank, world_size):
    # This rank hosts a contiguous block of experts (n_experts = len(expert_Ws) * world_size).
    experts_per_rank = len(expert_Ws)
    my_expert_start = rank * experts_per_rank
    my_expert_end = my_expert_start + experts_per_rank

    N, d = tokens.shape
    output = np.zeros((N, d), dtype=tokens.dtype)
    for i, token in enumerate(tokens):
        expert_idx = routing[i]
        # Only process tokens routed to an expert hosted on this rank.
        if my_expert_start <= expert_idx < my_expert_end:
            local_expert_idx = expert_idx - my_expert_start
            output[i] = token @ expert_Ws[local_expert_idx].T
    return output
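For a quick sanity check, here is a minimal single-process driver with hypothetical sizes (N=6 tokens, d=4, 4 experts over 2 ranks); because each token's output row is non-zero on exactly one rank, summing the per-rank outputs reconstructs the full result.

import numpy as np

N, d, n_experts, world_size = 6, 4, 4, 2
experts_per_rank = n_experts // world_size
tokens = np.random.randn(N, d)
routing = np.random.randint(0, n_experts, size=N)
all_Ws = [np.random.randn(d, d) for _ in range(n_experts)]   # global expert weights

outputs = [
    expert_parallel_dispatch(
        tokens, routing,
        all_Ws[r * experts_per_rank:(r + 1) * experts_per_rank],  # this rank's experts
        rank=r, world_size=world_size)
    for r in range(world_size)
]
combined = sum(outputs)  # (N, d): every token was processed by exactly one rank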
In practice, an all-to-all collective sends each token to its target rank. Here we simulate that by processing only the tokens assigned to this rank's experts.
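As a rough sketch of how the real dispatch could look with torch.distributed (assuming an initialized process group — e.g. the Gloo backend so plain CPU tensors work — and noting that the helper name and the size-exchange scheme below are illustrative, not part of the reference solution):

import torch
import torch.distributed as dist

def all_to_all_dispatch(tokens, routing, experts_per_rank, world_size):
    # Destination rank of each token follows from its expert index.
    dest = routing // experts_per_rank
    send_chunks = [tokens[dest == r] for r in range(world_size)]

    # Exchange chunk sizes first so receive buffers can be pre-allocated.
    send_counts = [torch.tensor([c.shape[0]]) for c in send_chunks]
    recv_counts = [torch.zeros(1, dtype=torch.long) for _ in range(world_size)]
    dist.all_to_all(recv_counts, send_counts)

    # Then exchange the token chunks themselves.
    recv_chunks = [torch.empty(int(n), tokens.shape[1], dtype=tokens.dtype)
                   for n in recv_counts]
    dist.all_to_all(recv_chunks, send_chunks)

    # Every token this rank's experts must process, concatenated.
    return torch.cat(recv_chunks, dim=0)

In a full MoE layer the expert outputs are then typically returned to their source ranks with a second all-to-all (the combine step), which this exercise omits.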
With 64 experts spread across 8 GPUs (8 experts/GPU), each GPU stores only 1/8 of the expert parameters. Communication happens once per MoE layer (all-to-all for tokens), enabling massive MoE scaling.
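As a hypothetical back-of-the-envelope check of that 1/8 figure (assuming d = 4096 and one square (d, d) fp16 weight matrix per expert, sizes chosen only for illustration):

d, n_experts, world_size, bytes_per_param = 4096, 64, 8, 2
total_bytes = n_experts * d * d * bytes_per_param   # 2 GiB of expert weights overall
per_gpu_bytes = total_bytes // world_size           # 256 MiB per GPU, i.e. 1/8
print(f"total: {total_bytes / 2**30:.1f} GiB, per GPU: {per_gpu_bytes / 2**20:.0f} MiB")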