Implement the DPO (Direct Preference Optimization) loss. Given per-example chosen/rejected log-probs from both the policy and the frozen reference model:
L = -mean( log_sigmoid( beta * ((lp_c_pol - lp_c_ref) - (lp_r_pol - lp_r_ref)) ) )
Signature: def dpo_loss(logp_chosen_policy: np.ndarray, logp_rejected_policy: np.ndarray, logp_chosen_ref: np.ndarray, logp_rejected_ref: np.ndarray, beta: float = 0.1) -> float
Use a numerically stable log_sigmoid:
log_sigmoid(x) = -log(1 + exp(-x)) if x >= 0
= x - log(1 + exp(x)) if x < 0
Return a Python float.
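One possible NumPy sketch of this function. It uses the branchless form `min(x, 0) - log1p(exp(-|x|))`, which is algebraically equivalent to the two-branch `log_sigmoid` definition above (for `x >= 0` the `min` term vanishes, giving `-log1p(exp(-x))`; for `x < 0` it gives `x - log1p(exp(x))`):

```python
import numpy as np

def log_sigmoid(x: np.ndarray) -> np.ndarray:
    # Numerically stable log(sigmoid(x)): min(x, 0) - log1p(exp(-|x|)).
    # exp is only ever applied to a non-positive argument, so it cannot overflow.
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

def dpo_loss(logp_chosen_policy: np.ndarray,
             logp_rejected_policy: np.ndarray,
             logp_chosen_ref: np.ndarray,
             logp_rejected_ref: np.ndarray,
             beta: float = 0.1) -> float:
    # Implicit-reward margin: how much more the policy prefers the chosen
    # response over the rejected one, relative to the frozen reference model.
    margin = ((logp_chosen_policy - logp_chosen_ref)
              - (logp_rejected_policy - logp_rejected_ref))
    # Negative mean log-sigmoid of the scaled margin, returned as a float.
    return float(-np.mean(log_sigmoid(beta * margin)))
```

Sanity check: when all four log-prob arrays are equal, the margin is zero and the loss is log(2) ≈ 0.6931; as the policy's margin on chosen over rejected grows, the loss decreases toward 0.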