Implement the DPO (Direct Preference Optimization) loss in PyTorch.
Signature: def dpo_loss(logp_chosen_policy: torch.Tensor, logp_rejected_policy: torch.Tensor, logp_chosen_ref: torch.Tensor, logp_rejected_ref: torch.Tensor, beta: float = 0.1) -> torch.Tensor
The rule: you may NOT call F.binary_cross_entropy_with_logits. You may use F.logsigmoid: it is the canonical, numerically stable building block for DPO, and the point of the PyTorch version is to use it instead of hand-rolling the piecewise stable log_sigmoid as the NumPy version does.
Allowed primitives: F.logsigmoid, .mean(), basic arithmetic.
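For context on the rule: with an all-ones target, F.binary_cross_entropy_with_logits(z, 1) computes exactly -logsigmoid(z), which is presumably why it is off-limits (it would reduce the exercise to a single call). The check below uses the banned function only to confirm that identity; it is not part of a solution. The tensor z is an arbitrary stand-in for beta * diff.

```python
import torch
import torch.nn.functional as F

# Arbitrary logits standing in for beta * diff (values chosen for illustration).
z = torch.tensor([-30.0, -2.0, 0.0, 2.0, 30.0])

# BCE-with-logits against an all-ones target, averaged over the batch (default reduction="mean").
bce = F.binary_cross_entropy_with_logits(z, torch.ones_like(z))

# The allowed formulation: negative mean of logsigmoid.
via_logsigmoid = -F.logsigmoid(z).mean()

print(torch.allclose(bce, via_logsigmoid))  # True
```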
Formula:
diff = (logp_chosen_policy - logp_chosen_ref) - (logp_rejected_policy - logp_rejected_ref)
loss = -mean( logsigmoid( beta * diff ) )
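A minimal sketch of the formula above using only the allowed primitives (F.logsigmoid, .mean(), arithmetic). It is not the site's reference solution, and it assumes each input is a 1-D tensor of per-sequence log-probabilities with one entry per preference pair; that shape is an assumption the problem does not state.

```python
import torch
import torch.nn.functional as F


def dpo_loss(
    logp_chosen_policy: torch.Tensor,
    logp_rejected_policy: torch.Tensor,
    logp_chosen_ref: torch.Tensor,
    logp_rejected_ref: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # How much more (in log space) the policy favors each response than the frozen reference does.
    chosen_margin = logp_chosen_policy - logp_chosen_ref
    rejected_margin = logp_rejected_policy - logp_rejected_ref
    diff = chosen_margin - rejected_margin

    # -log(sigmoid(beta * diff)), averaged over the batch.
    # F.logsigmoid stays finite even for large negative arguments.
    return -F.logsigmoid(beta * diff).mean()


# Usage: when the policy matches the reference, diff = 0 and the loss is log(2).
lp = torch.tensor([-12.0, -7.5])
print(dpo_loss(lp, lp - 1.0, lp, lp - 1.0))  # tensor(0.6931)

# Raising the policy's chosen log-probs by 5 gives beta * diff = 0.5, lowering the loss.
print(dpo_loss(lp + 5.0, lp - 1.0, lp, lp - 1.0))  # tensor(0.4741)
```

Averaging with .mean() matches the -mean(...) in the formula, so the loss scale does not grow with batch size; gradients flow through the policy log-probs as usual.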
PyTorch idioms vs the NumPy version:
F.logsigmoid(x) is a one-liner and numerically stable (mathematically it equals -softplus(-x)). NumPy has no equivalent, which is why the NumPy reference solution writes the piecewise form by hand.
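To see why the stable form matters, the snippet below compares the naive composition torch.log(torch.sigmoid(x)) with F.logsigmoid at a very negative argument and checks the identity logsigmoid(x) = -softplus(-x). The NumPy function np_log_sigmoid is a hypothetical sketch of a stable log-sigmoid of the kind the text describes (the two branches folded into one expression), not the actual NumPy reference solution.

```python
import numpy as np
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, -20.0, 0.0, 20.0])

# Naive: sigmoid(-200) underflows to 0 in float32, so the log becomes -inf.
naive = torch.log(torch.sigmoid(x))

# Stable: stays finite (exactly -200.0 for x = -200).
stable = F.logsigmoid(x)

# Mathematical identity: log(sigmoid(x)) = -softplus(-x).
via_softplus = -F.softplus(-x)


def np_log_sigmoid(a: np.ndarray) -> np.ndarray:
    """Hypothetical stable log-sigmoid in plain NumPy.

    Piecewise form: x >= 0 -> -log1p(exp(-x)); x < 0 -> x - log1p(exp(x)).
    Both branches equal min(x, 0) - log1p(exp(-|x|)), and exp only ever
    sees a non-positive argument, so it cannot overflow.
    """
    return np.minimum(a, 0) - np.log1p(np.exp(-np.abs(a)))


print(torch.isinf(naive[0]).item())  # True
print(stable[0].item())  # -200.0
print(torch.allclose(stable, via_softplus))  # True
print(np.allclose(np_log_sigmoid(x.numpy()), stable.numpy()))  # True
```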
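import torch
import torch.nn.functional as F

def dpo_loss(logp_chosen_policy: torch.Tensor, logp_rejected_policy: torch.Tensor, logp_chosen_ref: torch.Tensor, logp_rejected_ref: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    # Implement using only F.logsigmoid, .mean(), and basic arithmetic.
    pass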