TorchedUp

155. DPO Loss

Medium

Implement the Direct Preference Optimization (DPO) loss. Given the log-probabilities of the chosen and rejected responses under both the policy and the frozen reference model, compute:

L = -mean( log_sigmoid( beta * ((lp_c_pol - lp_c_ref) - (lp_r_pol - lp_r_ref)) ) )

Signature: def dpo_loss(logp_chosen_policy: np.ndarray, logp_rejected_policy: np.ndarray, logp_chosen_ref: np.ndarray, logp_rejected_ref: np.ndarray, beta: float = 0.1) -> float

Use a numerically stable log_sigmoid:

log_sigmoid(x) = -log(1 + exp(-x))      if x >= 0
              =  x - log(1 + exp(x))    if x <  0
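The two branches above can be folded into one expression, min(x, 0) - log(1 + exp(-|x|)), which never exponentiates a positive number. Below is a minimal NumPy sketch that assumes vectorized array input. The helper name `log_sigmoid` is our own choice and is not part of the required signature.

```python
import numpy as np


def log_sigmoid(x: np.ndarray) -> np.ndarray:
    """Numerically stable log(sigmoid(x)), applied elementwise."""
    x = np.asarray(x, dtype=np.float64)
    # For x >= 0: min(x, 0) = 0 and exp(-|x|) = exp(-x)  -> -log(1 + exp(-x))
    # For x <  0: min(x, 0) = x and exp(-|x|) = exp(x)   ->  x - log(1 + exp(x))
    # exp(-|x|) always lies in (0, 1], so it can never overflow.
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))


# Extreme inputs stay finite. A naive np.log(1 / (1 + np.exp(-x))) would
# overflow inside exp and return -inf for x = -1000.
print(log_sigmoid(np.array([-1000.0, 0.0, 1000.0])))
```

`np.log1p` is used instead of `np.log(1 + ...)` because it keeps precision when exp(-|x|) is tiny.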

Return a Python float.
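The following is one possible sketch of `dpo_loss` that puts the pieces together. It assumes all four arrays share the same shape, with one entry per preference pair, and it repeats the stable log-sigmoid inline so the block runs on its own. It is an illustrative implementation, not the site's reference solution.

```python
import numpy as np


def dpo_loss(
    logp_chosen_policy: np.ndarray,
    logp_rejected_policy: np.ndarray,
    logp_chosen_ref: np.ndarray,
    logp_rejected_ref: np.ndarray,
    beta: float = 0.1,
) -> float:
    # Policy-vs-reference log-ratio for each response (elementwise).
    chosen_logratio = np.asarray(logp_chosen_policy, dtype=np.float64) - np.asarray(
        logp_chosen_ref, dtype=np.float64
    )
    rejected_logratio = np.asarray(logp_rejected_policy, dtype=np.float64) - np.asarray(
        logp_rejected_ref, dtype=np.float64
    )

    # Scaled margin between the chosen and rejected log-ratios.
    logits = beta * (chosen_logratio - rejected_logratio)

    # Stable log_sigmoid: both branches of the piecewise definition in one expression.
    log_sig = np.minimum(logits, 0.0) - np.log1p(np.exp(-np.abs(logits)))

    # Negative mean over the batch, returned as a plain Python float.
    return float(-np.mean(log_sig))


# Usage: when the policy equals the reference, every logit is 0,
# so the loss equals log(2) up to floating-point rounding.
zeros = np.zeros(4)
print(dpo_loss(zeros, zeros, zeros, zeros))
```

Wrapping the result in `float(...)` converts the NumPy scalar into the plain Python float that the problem asks for.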

Math

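$$
\mathcal{L}_{\text{DPO}} = -\mathbb{E}\left[ \log \sigma\left( \beta \left( \log \frac{\pi(y_c \mid x)}{\pi_{\text{ref}}(y_c \mid x)} - \log \frac{\pi(y_r \mid x)}{\pi_{\text{ref}}(y_r \mid x)} \right) \right) \right]
$$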
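Here is a hand-worked example of the formula for a single preference pair. The numbers are illustrative and are not taken from any test case: policy log-probs are -1 (chosen) and -3 (rejected), reference log-probs are -2 for both responses, and beta = 0.1.

$$
\begin{aligned}
z &= \beta\Big[\big(\log\pi(y_c\mid x)-\log\pi_{\text{ref}}(y_c\mid x)\big) - \big(\log\pi(y_r\mid x)-\log\pi_{\text{ref}}(y_r\mid x)\big)\Big] \\
  &= 0.1\,\big[(-1-(-2)) - (-3-(-2))\big] = 0.1\,\big(1-(-1)\big) = 0.2 \\
\mathcal{L} &= -\log\sigma(0.2) = \log\!\left(1+e^{-0.2}\right) \approx 0.5981
\end{aligned}
$$

In this example, the policy raises the chosen response's likelihood relative to the reference and lowers the rejected one's. As a result, z > 0 and the loss falls below log 2 ≈ 0.6931, which is the value at z = 0.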

Related problems

  • DPO Loss (PyTorch), Medium


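Starter code (NumPy):

import numpy as np


def dpo_loss(logp_chosen_policy: np.ndarray, logp_rejected_policy: np.ndarray, logp_chosen_ref: np.ndarray, logp_rejected_ref: np.ndarray, beta: float = 0.1) -> float:
    # TODO: compute the DPO loss defined above and return it as a Python float.
    pass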
