TorchedUp

245. DPO Loss (PyTorch)

Medium

Implement the DPO (Direct Preference Optimization) loss in PyTorch.

Signature: def dpo_loss(logp_chosen_policy: torch.Tensor, logp_rejected_policy: torch.Tensor, logp_chosen_ref: torch.Tensor, logp_rejected_ref: torch.Tensor, beta: float = 0.1) -> torch.Tensor

The rule: you may NOT call F.binary_cross_entropy_with_logits. You may use F.logsigmoid: it is the canonical, numerically stable building block for DPO, and the point of the PyTorch version is to use it instead of hand-rolling the piecewise stable log-sigmoid as the NumPy version does.

Allowed primitives: F.logsigmoid, .mean(), basic arithmetic.
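A quick check of why the rule exists (not part of a submission): with an all-ones target, BCE-with-logits reduces to -log σ(z), so the banned call would solve the problem in one line. Here `z` is an arbitrary stand-in for `beta * diff`.

```python
import torch
import torch.nn.functional as F

z = torch.tensor([-3.0, 0.0, 2.5])  # stand-in for beta * diff

# The banned one-liner: BCE with logits against target 1 ("chosen wins")
bce = F.binary_cross_entropy_with_logits(z, torch.ones_like(z))
# The allowed building block yields the same batch-mean value
via_logsigmoid = -F.logsigmoid(z).mean()

print(torch.allclose(bce, via_logsigmoid))  # True
```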

Formula:

diff = (logp_chosen_policy - logp_chosen_ref) - (logp_rejected_policy - logp_rejected_ref)
loss = -mean( logsigmoid( beta * diff ) )
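A minimal sketch of the formula, assuming the four inputs are same-shaped tensors of per-example sequence log-probabilities. The intermediate names `chosen_logratio` and `rejected_logratio` are illustrative, not part of the problem statement.

```python
import torch
import torch.nn.functional as F


def dpo_loss(
    logp_chosen_policy: torch.Tensor,
    logp_rejected_policy: torch.Tensor,
    logp_chosen_ref: torch.Tensor,
    logp_rejected_ref: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # How far the policy has moved from the reference on each response
    chosen_logratio = logp_chosen_policy - logp_chosen_ref
    rejected_logratio = logp_rejected_policy - logp_rejected_ref
    # Preference margin: positive when the policy favors the chosen response more than the reference does
    diff = chosen_logratio - rejected_logratio
    # F.logsigmoid stays finite for large |beta * diff|; average over the batch
    return -F.logsigmoid(beta * diff).mean()
```

Computing the two log-ratios separately mirrors the formula term by term; regrouping as (logp_chosen_policy - logp_rejected_policy) - (logp_chosen_ref - logp_rejected_ref) is algebraically identical.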

PyTorch idioms vs the NumPy version:

  • F.logsigmoid(x) is a single call and numerically stable: it equals -softplus(-x) and is evaluated in a form that stays finite for large |x|. NumPy has no built-in log-sigmoid, which is why the NumPy reference solution writes the piecewise form by hand.
  • This is one of the few cases where the PyTorch port is cleaner than the NumPy version, not just a syntax translation.
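The stability claim can be checked directly. A small demo, assuming float32 inputs (PyTorch's default dtype): the naive log(sigmoid(x)) underflows to -inf for very negative x, while F.logsigmoid stays finite and matches the -softplus(-x) identity.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-200.0, -20.0, 0.0, 20.0, 200.0])  # float32

naive = torch.log(torch.sigmoid(x))  # sigmoid(-200) underflows to 0 in float32, so log gives -inf
stable = F.logsigmoid(x)             # finite everywhere; roughly x for very negative x
via_softplus = -F.softplus(-x)       # identity: log(sigmoid(x)) = -softplus(-x)

print(naive[0])                            # tensor(-inf)
print(stable[0])                           # tensor(-200.)
print(torch.allclose(stable, via_softplus))  # True
```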

Math

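$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}\Big[\log\sigma\Big(\beta\big((\log\pi_\theta(y_c) - \log\pi_{\text{ref}}(y_c)) - (\log\pi_\theta(y_r) - \log\pi_{\text{ref}}(y_r))\big)\Big)\Big]$$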
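Two consequences of this loss that make convenient sanity checks, writing diff for the margin inside the sigmoid. The first assumes the policy is an exact copy of the reference (for example, when it is initialized from it); the second is the per-example derivative, before the 1/N factor contributed by the batch mean.

```latex
% Policy equals reference: every log-ratio is 0, so diff = 0
\mathcal{L}_{\text{DPO}}\big|_{\pi_\theta = \pi_{\text{ref}}} = -\log\sigma(0) = \log 2 \approx 0.6931

% Per-example gradient w.r.t. diff, using \frac{d}{dz}\log\sigma(z) = 1 - \sigma(z) = \sigma(-z)
\frac{\partial}{\partial\,\mathrm{diff}}\Big[-\log\sigma(\beta\,\mathrm{diff})\Big] = -\beta\,\sigma(-\beta\,\mathrm{diff})
```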

Related problems

  • DPO Loss (Medium, NumPy)


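Starter code (PyTorch)

import torch
import torch.nn.functional as F


def dpo_loss(logp_chosen_policy: torch.Tensor, logp_rejected_policy: torch.Tensor,
             logp_chosen_ref: torch.Tensor, logp_rejected_ref: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    # Return -mean(logsigmoid(beta * diff)) using only the allowed primitives
    pass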
