TorchedUp

242. KL Divergence (PyTorch)

Easy

Implement KL(P || Q) in PyTorch using primitive tensor ops only.

Signature: def kl_divergence(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor

The rule: you may NOT call F.kl_div or nn.KLDivLoss. Both have an unusual signature: by default they take log Q as the first argument and P as the second. That ordering is exactly the porting confusion this problem highlights.
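The sketch below shows that argument order in practice. It is for local sanity checks only; the rule still bars F.kl_div from the submission. It assumes P and Q are strictly positive probability tensors of the same shape. With log_target=False (the default), F.kl_div(input, target) treats input as log Q and target as P. Its default reduction="mean" averages over elements, so reduction="sum" is passed here to match this problem's total.

```python
import torch
import torch.nn.functional as F

P = torch.tensor([0.5, 0.5])
Q = torch.tensor([0.25, 0.75])

# Direct formula: sum over all elements of P * (log P - log Q).
manual = (P * (P.log() - Q.log())).sum()

# F.kl_div expects log Q first and P second (log_target=False by default).
# reduction="sum" totals over all elements; the default "mean" would divide by the element count.
library = F.kl_div(Q.log(), P, reduction="sum")

print(manual.item(), library.item())
```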

Allowed primitives: .log(), .clamp(), .sum(), basic arithmetic.

Formula:

KL(P || Q) = sum( P * (log(P) - log(Q)) )    over all elements
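As a small worked example, take P = (0.5, 0.5) and Q = (0.25, 0.75) with the natural log, which is what .log() computes. The result is in nats:

```latex
\begin{aligned}
\mathrm{KL}(P \,\|\, Q) &= 0.5\,(\log 0.5 - \log 0.25) + 0.5\,(\log 0.5 - \log 0.75) \\
                        &= 0.5 \log 2 + 0.5 \log \tfrac{2}{3} \\
                        &= 0.5 \log \tfrac{4}{3} \approx 0.1438
\end{aligned}
```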

Clip Q to a small floor (1e-15) to avoid log(0). Skip positions where P == 0 so that 0 * log(0) is treated as 0 by convention.
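Here is a minimal sketch of the clamp-then-mask approach, using only the allowed primitives. It assumes P and Q have the same shape and that P holds non-negative values.

```python
import torch

def kl_divergence(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    # Floor Q at 1e-15 so log(Q) stays finite even where Q == 0.
    Q = Q.clamp(min=1e-15)
    # Skip positions where P == 0: 0 * log(0) contributes 0 by convention.
    mask = P != 0
    p = P[mask]
    q = Q[mask]
    # Sum over every remaining element, whatever the input's original shape.
    return (p * (p.log() - q.log())).sum()
```

Masking with P[mask] flattens the selected entries into a 1-D tensor. That is harmless here because the result is summed over every element anyway. If every entry of P is zero, the sum over an empty tensor returns 0.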

The reference sums over all elements (matching the NumPy version), so for batched inputs (B, K) the result is the total summed KL across the batch.
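For a (B, K) batch, the total the reference returns equals the sum of the per-row KL values, not their mean. The sketch below makes that visible by swapping in torch.where for boolean masking. This is an alternative technique that keeps the (B, K) shape so rows can still be summed separately. The example tensors are made up for illustration.

```python
import torch

P = torch.tensor([[0.5, 0.5],
                  [1.0, 0.0]])
Q = torch.tensor([[0.25, 0.75],
                  [0.5, 0.5]])

# log Q with the same 1e-15 floor as the problem statement.
logQ = Q.clamp(min=1e-15).log()
# Replace P == 0 slots with 1 before taking the log so no log(0) is ever evaluated.
safe_P = torch.where(P != 0, P, torch.ones_like(P))
# Elementwise KL terms; positions where P == 0 are forced to contribute 0.
terms = torch.where(P != 0, P * (safe_P.log() - logQ), torch.zeros_like(P))

per_row = terms.sum(dim=1)  # KL of each row, shape (B,)
total = terms.sum()         # what the reference returns: one scalar over all B*K elements
```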

PyTorch idioms vs the NumPy version:

  • .clamp(min=1e-15) is the PyTorch spelling of np.clip(arr, 1e-15, None).
  • Use boolean masking P[mask] (works the same as NumPy).
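The two idioms above line up one-to-one with the NumPy version. The sketch below computes the same value both ways on made-up float64 inputs. It deliberately includes a Q == 0 entry where P is nonzero, so the clamp is what keeps the result finite (though large).

```python
import numpy as np
import torch

P_np = np.array([0.7, 0.3, 0.0])
Q_np = np.array([0.5, 0.0, 0.5])

# NumPy: np.clip with no upper bound, then boolean masking.
Qc_np = np.clip(Q_np, 1e-15, None)
m_np = P_np != 0
kl_np = np.sum(P_np[m_np] * (np.log(P_np[m_np]) - np.log(Qc_np[m_np])))

# PyTorch: .clamp(min=...) plays the role of np.clip(..., None); masking reads the same.
P_t = torch.from_numpy(P_np)  # stays float64
Q_t = torch.from_numpy(Q_np)
Qc_t = Q_t.clamp(min=1e-15)
m_t = P_t != 0
kl_t = (P_t[m_t] * (P_t[m_t].log() - Qc_t[m_t].log())).sum()
```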

Math

$$D_{\mathrm{KL}}(P \parallel Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}$$

Related problems

  • KL Divergence (Easy, NumPy)

