DPO Loss (Medium)

Direct Preference Optimization Loss

Implement the DPO loss. Given per-example log-probabilities of the chosen and rejected completions under both the policy and the frozen reference model, compute:

L = -mean( log_sigmoid( beta * ((lp_c_pol - lp_c_ref) - (lp_r_pol - lp_r_ref)) ) )

Signature: def dpo_loss(logp_chosen_policy: np.ndarray, logp_rejected_policy: np.ndarray, logp_chosen_ref: np.ndarray, logp_rejected_ref: np.ndarray, beta: float = 0.1) -> float

Use a numerically stable log_sigmoid:

log_sigmoid(x) = -log(1 + exp(-x))      if x >= 0
              =  x - log(1 + exp(x))    if x <  0

Return a Python float.
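
A minimal sketch of one way to satisfy the spec. It uses the branchless identity min(x, 0) - log1p(exp(-|x|)), an assumed equivalent of the two-branch definition above, so exp() is never called on a large positive argument:

import numpy as np

def log_sigmoid(x: np.ndarray) -> np.ndarray:
    # Branchless form of the two-case definition above:
    # x >= 0 gives 0 - log1p(exp(-x)); x < 0 gives x - log1p(exp(x)).
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

def dpo_loss(
    logp_chosen_policy: np.ndarray,
    logp_rejected_policy: np.ndarray,
    logp_chosen_ref: np.ndarray,
    logp_rejected_ref: np.ndarray,
    beta: float = 0.1,
) -> float:
    # Implicit reward margins: how much more the policy favors each
    # completion than the frozen reference does.
    chosen_margin = logp_chosen_policy - logp_chosen_ref
    rejected_margin = logp_rejected_policy - logp_rejected_ref
    # DPO pushes the chosen margin above the rejected one.
    logits = beta * (chosen_margin - rejected_margin)
    return float(-np.mean(log_sigmoid(logits)))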


Test Cases

- all zeros -> log(2)
- chosen wins
- batch=2 (hidden, Premium)
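
For intuition, a quick check of the two public cases; the values follow directly from the formula (the hidden batch=2 case stays premium):

z = np.zeros(2)
# All-zero inputs: logits = 0 and log_sigmoid(0) = -log(2), so loss = log(2).
print(dpo_loss(z, z, z, z))  # 0.6931471805599453

# Chosen wins: a positive chosen-vs-rejected margin pulls the loss below log(2).
print(dpo_loss(np.array([2.0]), np.array([-2.0]), np.zeros(1), np.zeros(1)))
# ~0.5130 with beta = 0.1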