Implement the forward pass of a Mixture of Experts (MoE) layer as used in Mixtral, GPT-4, and DeepSeek-V3. Each token is routed to its top-k experts, and their outputs are combined as a weighted sum.
Signature: def moe_forward(x, gate_W, expert_Ws, top_k=2)
x: (N, d) — N tokens, each of dimension d
gate_W: (n_experts, d) — gating network weights
expert_Ws: list of n_experts weight matrices, each (d, d) — each expert is a single linear layer
top_k: number of experts per token
Returns: (N, d) — output tokens

# 1. Compute gate logits and softmax probabilities
gate_logits = x @ gate_W.T # (N, n_experts)
gate_probs = softmax(gate_logits) # (N, n_experts)
# 2. Select top-k experts per token
top_k_idx = argsort(gate_probs)[:, -top_k:] # (N, top_k) — highest probs
# 3. Renormalize selected weights (so top-k weights sum to 1)
selected_probs = take_along_axis(gate_probs, top_k_idx, axis=-1) # (N, top_k) — gather each token's top-k probs
weights = selected_probs / selected_probs.sum(-1, keepdims=True)
# 4. Weighted sum of expert outputs
for each token i:
    output[i] = sum over ki in 0..top_k-1 of weights[i, ki] * (expert_Ws[top_k_idx[i, ki]] @ x[i])
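Putting the four steps together, here is a minimal runnable sketch in NumPy (using NumPy and defining the softmax helper are assumptions on my part; the problem only fixes the moe_forward signature):

import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def moe_forward(x, gate_W, expert_Ws, top_k=2):
    N, d = x.shape

    # 1. Gate logits and probabilities over experts.
    gate_logits = x @ gate_W.T                                      # (N, n_experts)
    gate_probs = softmax(gate_logits, axis=-1)                      # (N, n_experts)

    # 2. Top-k expert indices per token (argsort ascending, take the last k).
    top_k_idx = np.argsort(gate_probs, axis=-1)[:, -top_k:]         # (N, top_k)

    # 3. Gather the selected probabilities and renormalize so they sum to 1.
    selected = np.take_along_axis(gate_probs, top_k_idx, axis=-1)   # (N, top_k)
    weights = selected / selected.sum(axis=-1, keepdims=True)       # (N, top_k)

    # 4. Weighted sum of the selected experts' outputs, token by token.
    output = np.zeros((N, d))
    for i in range(N):
        for ki in range(top_k):
            e = top_k_idx[i, ki]
            output[i] += weights[i, ki] * (x[i] @ expert_Ws[e].T)
    return output

A quick shape check with random inputs (values are illustrative only):

rng = np.random.default_rng(0)
N, d, n_experts = 4, 8, 6
x = rng.normal(size=(N, d))
gate_W = rng.normal(size=(n_experts, d))
expert_Ws = [rng.normal(size=(d, d)) for _ in range(n_experts)]
print(moe_forward(x, gate_W, expert_Ws, top_k=2).shape)   # (4, 8)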
Each expert in this problem is a simple linear transformation: expert_out = x[i] @ expert_W.T
Real MoE experts (e.g. Mixtral's) use a full FFN (two-layer MLP), but the dispatch/routing logic is identical — see the sketch below.
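For comparison, a sketch of what a two-layer FFN expert could look like; the function name ffn_expert, the weight names W1/W2, and the ReLU activation are assumptions for illustration (Mixtral's actual experts use a gated SwiGLU FFN):

def ffn_expert(x_i, W1, W2):
    # Hypothetical two-layer expert: up-projection (d -> d_ff), nonlinearity,
    # down-projection (d_ff -> d). Assumed shapes: W1 is (d_ff, d), W2 is (d, d_ff).
    h = np.maximum(0.0, x_i @ W1.T)   # ReLU stands in for Mixtral's gated SiLU
    return h @ W2.T

Only the per-expert expression in step 4 (x[i] @ expert_Ws[e].T) would be replaced by a call like ffn_expert(x[i], ...); the gating, top-k selection, and renormalization are unchanged.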