Implement the auxiliary load balancing loss for MoE models (from Switch Transformers / Mixtral). This loss penalizes uneven expert utilization to prevent expert collapse (where most tokens route to the same expert).
Signature: def load_balancing_loss(gate_probs, top_k_indices)
gate_probs: (N, n_experts) — softmax gating probabilities for all N tokens
top_k_indices: (N, top_k) — which experts were selected per token

# Fraction of tokens routed to each expert
f_i = (number of tokens with expert i in top-k) / (N * top_k) # (n_experts,)
# Mean routing probability to each expert
p_i = mean over tokens of gate_probs[:, i] # (n_experts,)
# Load balancing loss
loss = n_experts * sum(f_i * p_i)
The loss is minimized when all experts receive equal load (f_i = 1/n_experts for all i), and equals 1.0 at perfect balance.
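A minimal sketch of this computation, assuming PyTorch tensors (the problem doesn't fix a framework, so the use of torch here is an assumption):

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(gate_probs: torch.Tensor, top_k_indices: torch.Tensor) -> torch.Tensor:
    """Auxiliary load balancing loss (Switch Transformers / Mixtral style).

    gate_probs:    (N, n_experts) softmax gating probabilities
    top_k_indices: (N, top_k) indices of the experts selected per token
    """
    N, n_experts = gate_probs.shape
    top_k = top_k_indices.shape[1]

    # f_i: fraction of routing slots assigned to each expert.
    # one_hot -> (N, top_k, n_experts); sum over tokens and slots, divide by N * top_k.
    one_hot = F.one_hot(top_k_indices, num_classes=n_experts).float()
    f = one_hot.sum(dim=(0, 1)) / (N * top_k)   # (n_experts,)

    # p_i: mean gating probability assigned to each expert.
    p = gate_probs.mean(dim=0)                  # (n_experts,)

    # 1.0 at perfect balance; grows as routing becomes uneven.
    return n_experts * torch.sum(f * p)
```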
Without this loss, gradient descent tends to route most tokens to the best-performing experts (a positive feedback loop), since experts that receive more tokens also receive more gradient updates. The auxiliary loss adds a penalty proportional to how uneven the routing is. In practice, it is weighted by a small coefficient (e.g., 0.01) and added to the main task loss.
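For illustration, a hypothetical training-step snippet that combines the auxiliary loss with a stand-in task loss; the names router_logits and task_loss, the top_k of 2, and the 0.01 coefficient are assumptions for the example, not given by the problem:

```python
import torch

torch.manual_seed(0)
N, n_experts, top_k = 8, 4, 2

router_logits = torch.randn(N, n_experts)
gate_probs = torch.softmax(router_logits, dim=-1)      # (N, n_experts)
_, top_k_indices = gate_probs.topk(top_k, dim=-1)      # (N, top_k)

task_loss = torch.tensor(2.3)                          # stand-in for the main task loss
aux_coef = 0.01                                        # small auxiliary weight
total_loss = task_loss + aux_coef * load_balancing_loss(gate_probs, top_k_indices)
```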