
Expert Load Balancing Loss

Implement the auxiliary load balancing loss for mixture-of-experts (MoE) models (from Switch Transformers / Mixtral). This loss penalizes uneven expert utilization to prevent expert collapse (where most tokens route to the same expert).

Signature: def load_balancing_loss(gate_probs, top_k_indices)

  • gate_probs: (N, n_experts) — softmax gating probabilities for all N tokens
  • top_k_indices: (N, top_k) — which experts were selected per token
  • Returns: scalar loss value (float)
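For concreteness, here is one hypothetical input with N=4 tokens, 3 experts, and top_k=2; the values are illustrative only and are not the graded test data:

import numpy as np

# 4 tokens, 3 experts: each row of gate_probs sums to 1 (softmax output)
gate_probs = np.array([
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.1, 0.2, 0.7],
    [0.4, 0.4, 0.2],
])

# top_k=2 expert ids selected per token (the two largest probs in each row)
top_k_indices = np.array([
    [0, 1],
    [1, 2],
    [2, 1],
    [0, 1],
])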

Formula

# Fraction of tokens routed to each expert
f_i = (number of tokens with expert i in top-k) / (N * top_k)    # (n_experts,)

# Mean routing probability to each expert
p_i = mean over tokens of gate_probs[:, i]                         # (n_experts,)

# Load balancing loss
loss = n_experts * sum(f_i * p_i)

The loss is minimized when all experts receive equal load (f_i = 1/n_experts for all i), and equals 1.0 at perfect balance.
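A minimal numpy sketch of one way to compute this, assuming integer expert indices and inputs shaped as described above (a possible approach, not the official reference solution):

import numpy as np

def load_balancing_loss(gate_probs, top_k_indices):
    """Switch/Mixtral-style auxiliary load balancing loss.

    gate_probs:    (N, n_experts) softmax routing probabilities
    top_k_indices: (N, top_k) selected expert ids per token (integers)
    Returns a scalar float.
    """
    N, n_experts = gate_probs.shape
    top_k = top_k_indices.shape[1]

    # f_i: fraction of routing assignments that went to expert i
    counts = np.bincount(top_k_indices.reshape(-1), minlength=n_experts)
    f = counts / (N * top_k)            # (n_experts,)

    # p_i: mean gating probability assigned to expert i
    p = gate_probs.mean(axis=0)         # (n_experts,)

    return float(n_experts * np.sum(f * p))

With both f and p uniform (perfect balance), the sum is n_experts * (1/n_experts)^2 = 1/n_experts, so the loss evaluates to exactly 1.0, matching the statement above.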

Why This Matters

Without this loss, gradient descent tends to route all tokens to the best-performing expert (a positive feedback loop). The auxiliary loss adds a penalty proportional to how uneven the routing is. In practice, it is weighted by a small coefficient (e.g., 0.01) and added to the main task loss, as sketched below.
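In a training loop that might look roughly like the following; the coefficient and variable names (task_loss, aux_coeff) are illustrative, not part of the problem statement:

aux_coeff = 0.01  # small weight so the auxiliary term does not dominate the task loss
total_loss = task_loss + aux_coeff * load_balancing_loss(gate_probs, top_k_indices)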


Test Results

  • seed=42 Dirichlet probs, N=4, 3 experts, top_k=2
  • perfectly balanced: each expert gets exactly 1/n_experts of the tokens → loss = n_experts * (1/n_experts) = 1
  • all tokens go to expert 0 (collapse): high loss (sanity-checked in the sketch below)
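As a quick sanity check against the balanced and collapsed cases described above, using the sketch from the Formula section (the exact values are worked out here by hand, not taken from the hidden tests):

import numpy as np

# Perfectly balanced: 4 tokens, 2 experts, top_k=1, uniform gates
gate_probs = np.full((4, 2), 0.5)
top_k_indices = np.array([[0], [1], [0], [1]])
print(load_balancing_loss(gate_probs, top_k_indices))   # -> 1.0

# Collapse: every token routes to expert 0 with high probability
gate_probs = np.tile(np.array([[0.9, 0.1]]), (4, 1))
top_k_indices = np.zeros((4, 1), dtype=int)
print(load_balancing_loss(gate_probs, top_k_indices))   # -> 2 * (1*0.9 + 0*0.1) = 1.8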