TorchedUp · Problems
KL Penalty (K1 Estimator) · Easy

KL Penalty (K1 Estimator)

In RLHF we penalize the policy for drifting from a reference. With a single sample per token, the simplest unbiased estimator of KL(pi || pi_ref) is:

kl_estimate = mean( log_probs_policy - log_probs_ref )

Signature: def kl_penalty(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float

Return a Python float.
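One way to satisfy this signature is a direct translation of the formula above; a minimal sketch:

```python
import numpy as np

def kl_penalty(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float:
    # Per-token log-ratio log(pi / pi_ref); averaging it over tokens sampled
    # from the policy gives the single-sample (K1) estimate of KL(pi || pi_ref).
    return float(np.mean(log_probs_policy - log_probs_ref))
```

For identical log-prob arrays the per-token ratios are all zero, so `kl_penalty(x, x)` returns `0.0`, matching the first test case.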

(This is the "K1" estimator: unbiased but high-variance. Many implementations swap in K3, (r - 1) - log(r) where r = pi_ref/pi, which is also unbiased but has lower variance and is nonnegative for every sample.)
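For comparison, the K3 variant can be sketched in the same style; the name `kl_penalty_k3` is an assumption, not part of the problem's required signature:

```python
import numpy as np

def kl_penalty_k3(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float:
    # log r = log(pi_ref / pi) per token; (r - 1) - log(r) = e^x - 1 - x with
    # x = log r, which is >= 0 for every sample and, in expectation under the
    # policy, equals KL(pi || pi_ref).
    log_r = log_probs_ref - log_probs_policy
    return float(np.mean(np.exp(log_r) - 1.0 - log_r))
```

Since e^x - 1 - x is never negative, each per-token term is nonnegative, which is part of why K3 has lower variance than K1 in practice.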

Math

Python (numpy)

Test Results

- identical -> 0
- policy sharper
- mixed (hidden, Premium)