Given a list of per-token negative log-likelihoods (natural log), compute the perplexity of a language model's predictions.
Signature: def perplexity(neg_log_likelihoods: list) -> float
Perplexity is the exponential of the mean negative log-likelihood over all tokens. The output is a single Python float. Reduce over every element of the input: for a 2D input of shape (B, T), the mean is taken over the batch and time axes together, not over just one of them.
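To see why the reduction axis matters, here is a small sketch using NumPy on a made-up batch of two sequences. It compares one global mean over all B × T tokens with the tempting alternative of computing a perplexity per sequence (mean over time only) and then averaging those. The two numbers differ, and only the first matches the definition above.

```python
import numpy as np

# Hypothetical batch: B=2 sequences, T=2 tokens each (natural-log NLLs).
nll = np.array([[0.0, 0.0],
                [np.log(4.0), np.log(4.0)]])

# Correct: one mean over all B*T tokens, then exponentiate.
ppl_all = float(np.exp(nll.mean()))

# Pitfall: mean over time only (axis=1) yields per-sequence perplexities;
# averaging those afterwards is a different quantity.
ppl_per_seq_avg = float(np.exp(nll.mean(axis=1)).mean())

print(ppl_all)          # approximately 2.0
print(ppl_per_seq_avg)  # approximately 2.5
```

The global mean is ln 4 / 2 = ln 2, giving 2.0; the per-sequence perplexities are 1 and 4, whose average is 2.5.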
Note: NLLs are assumed to use the natural logarithm (base e).
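The base matters because the exponentiation must match the logarithm that produced the NLLs. A quick check with the standard `math` module, using made-up values: NLLs measured in bits (base 2) can be converted to nats by multiplying by ln 2, after which exp of the mean agrees with 2 raised to the mean in bits. Applying exp directly to base-2 values gives a wrong, inflated result.

```python
import math

nll_bits = [1.0, 2.0, 3.0]                      # hypothetical base-2 NLLs
nll_nats = [b * math.log(2) for b in nll_bits]  # convert bits -> nats

mean_bits = sum(nll_bits) / len(nll_bits)
mean_nats = sum(nll_nats) / len(nll_nats)

ppl_from_bits = 2 ** mean_bits        # base-2 exponent for base-2 NLLs
ppl_from_nats = math.exp(mean_nats)   # base-e exponent for natural-log NLLs
wrong = math.exp(mean_bits)           # base mismatch: e**2, not 2**2

print(ppl_from_bits, ppl_from_nats)   # both approximately 4.0
print(wrong)                          # approximately 7.389
```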
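Math
PPL = exp( (1/N) · Σ_{i=1}^{N} NLL_i ), where NLL_i = −ln p(x_i | x_<i) is the negative log-likelihood of the i-th token and N is the total number of tokens (N = B × T for a 2D input).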
import numpy as np

def perplexity(neg_log_likelihoods: list) -> float:
    # TODO: return exp(mean NLL) over all elements as a Python float.
    pass
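One possible way to fill in the stub, offered as a sketch rather than the reference solution: `np.asarray` accepts both a flat list and a nested (B, T) list, and `.mean()` with no `axis` argument averages over every element. Rejecting empty input with a `ValueError` is an assumption here, since the statement does not specify that case.

```python
import numpy as np

def perplexity(neg_log_likelihoods: list) -> float:
    # Accept a flat list of length T or a nested list/array of shape (B, T).
    nll = np.asarray(neg_log_likelihoods, dtype=np.float64)
    if nll.size == 0:
        # Assumption: the mean over zero tokens is undefined, so fail loudly.
        raise ValueError("neg_log_likelihoods must contain at least one value")
    # .mean() without an axis reduces over batch and time together.
    return float(np.exp(nll.mean()))

print(perplexity([0.0, 0.0, 0.0]))           # 1.0
print(perplexity([[0.5, 1.5], [1.0, 1.0]]))  # exp(1.0), approximately 2.718
```

Converting with `float(...)` ensures the function returns a plain Python float rather than a NumPy scalar.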