Implement the GRPO (Group Relative Policy Optimization) loss in PyTorch.
Signature: def grpo_loss(rewards: torch.Tensor, log_probs: torch.Tensor, beta_kl: float, kl: torch.Tensor) -> torch.Tensor
The rule: you may NOT call any high-level loss wrapper. Implement the z-score and the policy-gradient surrogate yourself.
Allowed primitives: .mean(), .std(...), basic arithmetic.
Formula:
advantages = (rewards - rewards.mean()) / (rewards.std(unbiased=False) + 1e-8)
loss = -mean(advantages * log_probs) + beta_kl * mean(kl)
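The two formula lines above translate almost directly into PyTorch. The following is a minimal sketch of one possible implementation, not an official reference solution; the sample tensors at the bottom are made-up values chosen only to make the arithmetic easy to check by hand.

```python
import torch


def grpo_loss(rewards: torch.Tensor, log_probs: torch.Tensor,
              beta_kl: float, kl: torch.Tensor) -> torch.Tensor:
    # Z-score the rewards within the group using the population std (divisor n).
    advantages = (rewards - rewards.mean()) / (rewards.std(unbiased=False) + 1e-8)
    # Policy-gradient surrogate, negated so that minimizing the loss
    # raises the log-probability of above-average samples.
    pg_term = -(advantages * log_probs).mean()
    # KL penalty, scaled by beta_kl.
    kl_term = beta_kl * kl.mean()
    return pg_term + kl_term


# Illustrative inputs (hypothetical values, one group of four samples).
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
log_probs = torch.tensor([-0.5, -1.0, -1.5, -1.0])
kl = torch.tensor([0.1, 0.2, 0.3, 0.4])

loss = grpo_loss(rewards, log_probs, beta_kl=0.1, kl=kl)
print(loss.item())
```

With these inputs the advantages are approximately [1, -1, -1, 1], so the surrogate term is -0.25 and the KL term is 0.1 * 0.25 = 0.025, giving a loss of about -0.225.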
Critical PyTorch detail: PyTorch's .std() defaults to unbiased=True (Bessel-corrected, divisor n-1), while NumPy's .std() defaults to ddof=0 (divisor n). To match the NumPy reference and the GRPO definition used in this problem, you must pass unbiased=False (recent PyTorch releases also accept the equivalent correction=0). Omitting it is a common bug when porting GRPO code from NumPy to PyTorch.
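The difference between the two defaults is easy to see on a small tensor. The snippet below is an illustrative check with arbitrary sample values; it compares PyTorch's default standard deviation, the unbiased=False variant, and NumPy's default.

```python
import numpy as np
import torch

values = [1.0, 2.0, 3.0, 4.0]  # arbitrary sample data
t = torch.tensor(values)

default_std = t.std()                    # Bessel-corrected, divisor n-1
population_std = t.std(unbiased=False)   # divisor n
numpy_std = np.std(np.array(values))     # ddof=0 by default, divisor n

print(default_std.item())     # roughly 1.291
print(population_std.item())  # roughly 1.118
print(numpy_std)              # roughly 1.118
```

Only the unbiased=False result agrees with NumPy. For small groups the gap is large (about 15% here with n = 4), which changes the scale of every advantage.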
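import torch

def grpo_loss(rewards: torch.Tensor, log_probs: torch.Tensor, beta_kl: float, kl: torch.Tensor) -> torch.Tensor:
    # Your implementation goes here.
    pass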