Implement KL(P || Q) in PyTorch using primitive tensor ops only.
Signature: def kl_divergence(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor
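The rule: you may NOT call F.kl_div or nn.KLDivLoss. Both have an unusual interface. They expect log Q as the input (already in log space) and P as the target (as probabilities, unless log_target=True), in that order. That is the reverse of the order in KL(P || Q). Their default reduction='mean' also averages over elements instead of summing. This mismatch is exactly the porting confusion this problem highlights.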
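For illustration only, since calling F.kl_div is off-limits in the solution itself, the sketch below checks that argument order on two made-up distributions. It assumes strictly positive P and Q, so no clamping or masking is involved. It also passes reduction="sum" explicitly to match the summed definition.

```python
import torch
import torch.nn.functional as F

# Made-up, strictly positive distributions (illustrative values only).
P = torch.tensor([0.1, 0.4, 0.5])
Q = torch.tensor([0.3, 0.3, 0.4])

# KL(P || Q) written out directly from the definition (natural log, so nats).
manual = (P * (P.log() - Q.log())).sum()

# F.kl_div takes log Q as `input` and P as `target`; the default reduction
# is "mean", so "sum" is requested explicitly to match the definition.
via_kl_div = F.kl_div(Q.log(), P, reduction="sum")

# Passing the arguments in the "natural" order computes KL(Q || P) instead.
swapped = F.kl_div(P.log(), Q, reduction="sum")

print(manual.item(), via_kl_div.item(), swapped.item())
```

The first two values agree, while the swapped call returns a different number because KL divergence is not symmetric.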
Allowed primitives: .log(), .clamp(), .sum(), boolean masking, and basic arithmetic.
Formula:
KL(P || Q) = sum( P * (log(P) - log(Q)) ) over all elements
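A small worked example of the formula, using illustrative values that are not taken from the problem's tests. Because .log() is the natural logarithm, the result is in nats.

```latex
\begin{aligned}
P &= (0.5,\ 0.5), \qquad Q = (0.25,\ 0.75) \\
\mathrm{KL}(P \,\|\, Q) &= 0.5\,(\ln 0.5 - \ln 0.25) + 0.5\,(\ln 0.5 - \ln 0.75) \\
&= 0.5 \ln 2 + 0.5 \ln\tfrac{2}{3} = 0.5 \ln\tfrac{4}{3} \approx 0.1438 \text{ nats}
\end{aligned}
```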
Clip Q to a small floor (1e-15) to avoid log(0). Skip positions where P == 0 so that 0 * log(0) is treated as 0 by convention.
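To see why both safeguards are needed, this sketch uses made-up distributions with a zero in P and a zero in Q. Without safeguards, the formula produces nan. With the clamp and the P > 0 mask, it produces a finite number. The names naive, Q_safe, mask and safe are illustrative.

```python
import torch

# Made-up distributions: P has a zero at index 0, Q has a zero at index 1.
P = torch.tensor([0.0, 0.5, 0.5])
Q = torch.tensor([0.2, 0.0, 0.8])

# Naive formula: 0 * log(0) = 0 * (-inf) = nan, and log(Q) = -inf where Q == 0,
# so the sum is nan.
naive = (P * (P.log() - Q.log())).sum()

# Floor Q so log(Q) stays finite (log(1e-15) is about -34.5).
Q_safe = Q.clamp(min=1e-15)

# Keep only positions with P > 0, which implements the 0 * log(0) = 0 convention.
mask = P > 0
safe = (P[mask] * (P[mask].log() - Q_safe[mask].log())).sum()
```

Note that where Q == 0 but P > 0, the true divergence is infinite. The floor replaces it with a large finite value (here about 16.69), so the 1e-15 constant directly affects the result in that case.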
The reference sums over all elements (matching the NumPy version), so for batched inputs of shape (B, K) the result is a single scalar: the total KL summed across all B rows, not a per-row vector.
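The sketch below shows what the full reduction means for a batch, with randomly generated rows. It compares the total with per-row values and with a batch average. Only `total` corresponds to what the reference returns. The per-row and "batchmean"-style values are shown for contrast and are not part of the specification.

```python
import torch

torch.manual_seed(0)
B, K = 4, 5
# Each row of P and Q is a valid distribution (softmax over the last dim).
P = torch.softmax(torch.randn(B, K), dim=-1)
Q = torch.softmax(torch.randn(B, K), dim=-1)

# Elementwise KL terms; softmax outputs are strictly positive, so no mask is needed here.
terms = P * (P.log() - Q.log())

total = terms.sum()           # what the reference returns: a 0-d tensor
per_row = terms.sum(dim=-1)   # shape (B,): one KL value per row
mean_per_row = total / B      # "batchmean"-style average, if an average were wanted
```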
PyTorch idioms vs the NumPy version:
- .clamp(min=1e-15) is the PyTorch spelling of np.clip(arr, 1e-15, None).
- Boolean-mask indexing such as P[mask] (with mask = P > 0) works the same as in NumPy.
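Putting the pieces together, here is a minimal sketch of kl_divergence that stays within the allowed primitives (clamp, boolean masking, log, arithmetic, sum). It is one possible solution, not the site's hidden reference. The local names mask, P_pos and Q_pos are illustrative.

```python
import torch

def kl_divergence(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
    """KL(P || Q) summed over all elements, using only primitive tensor ops."""
    # Floor Q at 1e-15 so log(Q) is finite (PyTorch spelling of np.clip(Q, 1e-15, None)).
    Q = Q.clamp(min=1e-15)
    # Drop positions where P == 0: by convention 0 * log(0) contributes 0.
    mask = P > 0
    P_pos = P[mask]
    Q_pos = Q[mask]
    # Sum over every remaining element, so batched inputs give one total.
    return (P_pos * (P_pos.log() - Q_pos.log())).sum()
```

For example, kl_divergence(torch.tensor([0.5, 0.5]), torch.tensor([0.25, 0.75])) returns a 0-d tensor of about 0.1438, that is 0.5 * ln(4/3). Boolean indexing returns a flattened 1-D selection, which is harmless here because the result is summed over every element anyway.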