TorchedUp

247. GRPO Loss (PyTorch)

Hard

Implement the GRPO (Group Relative Policy Optimization) loss in PyTorch.

Signature: def grpo_loss(rewards: torch.Tensor, log_probs: torch.Tensor, beta_kl: float, kl: torch.Tensor) -> torch.Tensor

The rule: you may NOT call any high-level loss wrapper. Implement the z-score and the policy-gradient surrogate yourself.

Allowed primitives: .mean(), .std(...), basic arithmetic.

Formula:

advantages = (rewards - rewards.mean()) / (rewards.std(unbiased=False) + 1e-8)
loss       = -mean(advantages * log_probs) + beta_kl * mean(kl)
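The two formula lines above translate almost directly into PyTorch. The following is a minimal sketch, not the site's reference solution; the input values are hypothetical and chosen only so the arithmetic is easy to follow.

```python
import torch


def grpo_loss(rewards: torch.Tensor, log_probs: torch.Tensor,
              beta_kl: float, kl: torch.Tensor) -> torch.Tensor:
    # Group-relative advantage: z-score of the rewards, using the
    # population standard deviation (divisor n) plus a small epsilon.
    advantages = (rewards - rewards.mean()) / (rewards.std(unbiased=False) + 1e-8)
    # Policy-gradient surrogate plus the KL penalty term.
    return -(advantages * log_probs).mean() + beta_kl * kl.mean()


# Hypothetical inputs: one group of four sampled completions.
rewards = torch.tensor([1.0, 2.0, 3.0, 4.0])
log_probs = torch.tensor([-1.0, -2.0, -3.0, -4.0])
kl = torch.tensor([0.1, 0.2, 0.3, 0.4])

loss = grpo_loss(rewards, log_probs, beta_kl=0.1, kl=kl)
print(loss.item())  # approximately 1.1430
```

Here the surrogate term contributes about 1.1180 and the KL penalty contributes 0.1 * 0.25 = 0.025.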

Critical PyTorch detail: PyTorch's .std() defaults to unbiased=True (Bessel-corrected, divisor n-1). NumPy's .std() defaults to ddof=0 (divisor n). To match the NumPy reference and the standard GRPO definition, you must pass unbiased=False. This is the most common GRPO porting bug.
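The difference between the two divisors is easy to see on a small tensor. The sketch below uses hypothetical reward values and compares PyTorch's default, the `unbiased=False` variant, and a population standard deviation computed from the definition.

```python
import torch

# Hypothetical rewards for a group of four samples.
rewards = torch.tensor([1.0, 2.0, 3.0, 4.0])
n = rewards.numel()

std_default = rewards.std()                   # unbiased=True, divisor n - 1
std_population = rewards.std(unbiased=False)  # divisor n

# Population std from the definition, using only mean and arithmetic.
std_manual = ((rewards - rewards.mean()) ** 2).mean().sqrt()

print(std_default.item(), std_population.item(), std_manual.item())

# The two estimators differ by a factor of sqrt(n / (n - 1)).
ratio = std_default / std_population
print(ratio.item(), (n / (n - 1)) ** 0.5)
```

For small groups the gap is large: with n = 4 the default is about 15% bigger than the population value, which shrinks every advantage by the same factor.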

Math

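$$
\mathcal{L}_{\mathrm{GRPO}} = -\frac{1}{G}\sum_{i} A_i \log \pi(y_i \mid x) + \beta\, D_{\mathrm{KL}}, \qquad A_i = \frac{r_i - \bar{r}}{\sigma_r + \varepsilon}
$$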
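One consequence of this formula is that the gradient of the loss with respect to each log-probability is -A_i / G, where G is the group size. The sketch below checks this with autograd. It uses hypothetical values and assumes, for illustration only, that the KL tensor is a constant that does not depend on `log_probs`.

```python
import torch

# Hypothetical group of four samples.
rewards = torch.tensor([0.0, 1.0, 1.0, 4.0])
log_probs = torch.tensor([-0.5, -1.0, -1.5, -2.0], requires_grad=True)
kl = torch.tensor([0.05, 0.10, 0.15, 0.20])  # assumed constant w.r.t. log_probs
beta_kl = 0.1

# Same formula as in the problem statement.
advantages = (rewards - rewards.mean()) / (rewards.std(unbiased=False) + 1e-8)
loss = -(advantages * log_probs).mean() + beta_kl * kl.mean()
loss.backward()

# Analytic gradient of the surrogate term: dL / d log_prob_i = -A_i / G.
G = rewards.numel()
expected_grad = -advantages / G
print(torch.allclose(log_probs.grad, expected_grad))
```

Samples with above-average reward get a negative gradient on their log-probability, so a gradient-descent step raises their likelihood, and samples with below-average reward are pushed down.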

Related problems

  • GRPO Loss (Hard, NumPy)

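PyTorch

import torch

def grpo_loss(rewards: torch.Tensor, log_probs: torch.Tensor, beta_kl: float, kl: torch.Tensor) -> torch.Tensor:
    pass  # TODO: implement the formula above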
