In RLHF we penalize the policy for drifting from a reference. With a single sample per token, the simplest unbiased estimator of KL(pi || pi_ref) is:
kl_estimate = mean(log_probs_policy - log_probs_ref)
Signature: def kl_penalty(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float
Return a Python float.
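A minimal sketch of one way to satisfy the signature above, transcribing the K1 formula directly (the body is illustrative, not a reference solution):

```python
import numpy as np

def kl_penalty(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float:
    # K1 estimator: mean per-token log-ratio log(pi / pi_ref).
    # Unbiased for KL(pi || pi_ref) when the tokens were sampled from pi.
    return float(np.mean(log_probs_policy - log_probs_ref))
```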
(This is the "K1" estimator: unbiased but high-variance, and individual samples can be negative. Many implementations swap in K3, (r - 1) - log(r) where r = pi_ref/pi, which is also unbiased but has lower variance and is non-negative for every sample.)
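For comparison, here is a sketch of the K3 variant under the same signature; the name kl_penalty_k3 is illustrative and not part of the prompt:

```python
import numpy as np

def kl_penalty_k3(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float:
    # r = pi_ref / pi, computed from the log-probs for numerical stability.
    log_r = log_probs_ref - log_probs_policy
    r = np.exp(log_r)
    # K3 estimator: (r - 1) - log(r). Unbiased, lower variance than K1,
    # and per-sample non-negative since r - 1 >= log(r) for all r > 0.
    return float(np.mean((r - 1.0) - log_r))
```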