TorchedUp

158. KL Penalty (K1 Estimator)

Easy

In RLHF we penalize the policy for drifting from a reference model. With a single sampled token per position, drawn from the policy itself, the simplest unbiased estimator of KL(pi || pi_ref) is:

kl_estimate = mean( log_probs_policy - log_probs_ref )

Signature: def kl_penalty(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float

Return a Python float.
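A minimal sketch of one way to satisfy this signature is shown below. The shape check and the cast to float64 are assumptions added for illustration; the problem statement does not require them.

```python
import numpy as np


def kl_penalty(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float:
    """K1 estimate of KL(pi || pi_ref) from per-token log-probabilities."""
    # Assumption: accept anything array-like and compute in float64.
    policy = np.asarray(log_probs_policy, dtype=np.float64)
    ref = np.asarray(log_probs_ref, dtype=np.float64)

    # Assumption: the two arrays are aligned token by token.
    if policy.shape != ref.shape:
        raise ValueError("log-prob arrays must have the same shape")

    # Mean of the per-token log-ratio log(pi / pi_ref);
    # float() turns the NumPy scalar into a plain Python float.
    return float(np.mean(policy - ref))


# Each token contributes a log-ratio of 0.5, so the mean is 0.5.
example = kl_penalty(np.array([-1.0, -2.0]), np.array([-1.5, -2.5]))
```

Wrapping the result in float() matters because np.mean returns a NumPy scalar rather than a built-in float.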

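(This is the "K1" estimator: unbiased and cheap, but high-variance, and individual terms can be negative. Many implementations swap in K3, (r - 1) - log(r) where r = pi_ref/pi, which is also unbiased, is never negative, and typically has lower variance.)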
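To make the K1 versus K3 comparison concrete, here is a small sketch that evaluates both estimators on samples from a toy three-token distribution. The helper names kl_k1 and kl_k3, the two distributions, and the sample size are all hypothetical choices for illustration, not part of the problem.

```python
import numpy as np


def kl_k1(lp_policy: np.ndarray, lp_ref: np.ndarray) -> float:
    # K1: mean of log(pi / pi_ref) over tokens sampled from pi.
    return float(np.mean(lp_policy - lp_ref))


def kl_k3(lp_policy: np.ndarray, lp_ref: np.ndarray) -> float:
    # K3: mean of (r - 1) - log(r), with r = pi_ref / pi.
    log_r = lp_ref - lp_policy
    # expm1(log_r) equals r - 1 and stays accurate when r is close to 1.
    return float(np.mean(np.expm1(log_r) - log_r))


# Toy distributions over a 3-token vocabulary (illustrative values only).
pi = np.array([0.5, 0.3, 0.2])
pi_ref = np.array([0.4, 0.4, 0.2])
true_kl = float(np.sum(pi * np.log(pi / pi_ref)))

# Sample tokens from the policy, then look up both log-probabilities.
rng = np.random.default_rng(0)
tokens = rng.choice(len(pi), size=200_000, p=pi)
lp_policy = np.log(pi)[tokens]
lp_ref = np.log(pi_ref)[tokens]

k1_est = kl_k1(lp_policy, lp_ref)
k3_est = kl_k3(lp_policy, lp_ref)
print(f"exact={true_kl:.5f}  k1={k1_est:.5f}  k3={k3_est:.5f}")
```

With enough samples both estimates should land close to the exact value. The practical difference is per token: a K1 term is negative whenever the reference assigns the sampled token more probability than the policy does, while every K3 term is at least zero.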

Math

$$\mathrm{KL}_{K1} = \frac{1}{N} \sum_{i=1}^{N} \left( \log \pi(y_i) - \log \pi_{\mathrm{ref}}(y_i) \right)$$
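A short check of why this average targets the right quantity, assuming each $y_i$ is sampled from $\pi$ and that $\pi_{\mathrm{ref}}$ puts no mass outside the support of $\pi$:

$$\mathbb{E}_{y \sim \pi}\left[ \log \pi(y) - \log \pi_{\mathrm{ref}}(y) \right] = \sum_{y} \pi(y) \log \frac{\pi(y)}{\pi_{\mathrm{ref}}(y)} = \mathrm{KL}(\pi \,\|\, \pi_{\mathrm{ref}})$$

The K3 variant adds a term with zero mean. With $r = \pi_{\mathrm{ref}}(y) / \pi(y)$:

$$\mathbb{E}_{y \sim \pi}[r - 1] = \sum_{y} \pi(y) \frac{\pi_{\mathrm{ref}}(y)}{\pi(y)} - 1 = 0, \qquad \mathbb{E}_{y \sim \pi}\left[ (r - 1) - \log r \right] = \mathrm{KL}(\pi \,\|\, \pi_{\mathrm{ref}})$$

Because $\log r \le r - 1$ for every $r > 0$, each K3 term is non-negative, whereas a single K1 term can take either sign.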

NumPy

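import numpy as np


def kl_penalty(log_probs_policy: np.ndarray, log_probs_ref: np.ndarray) -> float:
    # TODO: return the K1 estimate as a Python float.
    pass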
