Implement BERT's next sentence prediction (NSP) loss: a binary (two-class) cross-entropy over the logits computed from the [CLS] token.
Signature: def nsp_loss(cls_logits: np.ndarray, labels: np.ndarray) -> float
cls_logits: shape (batch, 2)
labels: shape (batch,), with values in {0, 1}
Return the mean cross-entropy. Use a numerically stable softmax.
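Written out per example, with the max-shift that makes the softmax numerically stable, the quantity to compute is below. The symbols z, y, m, d and B are notation introduced here for illustration, not names from the problem statement.

```latex
% z_i = (z_{i,0}, z_{i,1}): the two logits for example i;  y_i \in \{0, 1\}: its label;  B: batch size
\ell_i = -\log \frac{e^{z_{i,y_i}}}{e^{z_{i,0}} + e^{z_{i,1}}}
       = m_i + \log\!\left(e^{z_{i,0} - m_i} + e^{z_{i,1} - m_i}\right) - z_{i,y_i},
\qquad m_i = \max(z_{i,0},\, z_{i,1})

\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \ell_i

% Equivalent binary form, with d_i = z_{i,1} - z_{i,0}:
\ell_i = \log\!\left(1 + e^{d_i}\right) - y_i\, d_i
```

After the shift every exponent is at most 0, so nothing overflows, and one term equals exactly 1, so the argument of the log is at least 1. The last line shows why a two-class softmax cross-entropy is the same thing as a binary cross-entropy applied to the logit difference d_i.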
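import numpy as np

def nsp_loss(cls_logits: np.ndarray, labels: np.ndarray) -> float:
    # TODO: return the mean cross-entropy over the batch
    pass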
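A minimal sketch of one possible solution, assuming float logits and integer (or integer-valued) labels. It computes the log-softmax directly with the max-shift instead of taking the log of a softmax. The shape and value checks and their error messages are illustrative additions, not requirements from the problem statement.

```python
import numpy as np


def nsp_loss(cls_logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean two-class cross-entropy over NSP logits.

    cls_logits: shape (batch, 2), raw (unnormalized) scores.
    labels: shape (batch,), class indices in {0, 1}.
    """
    logits = np.asarray(cls_logits, dtype=np.float64)
    labels = np.asarray(labels).astype(np.int64)
    if logits.ndim != 2 or logits.shape[1] != 2:
        raise ValueError(f"cls_logits must have shape (batch, 2), got {logits.shape}")
    if labels.shape != (logits.shape[0],):
        raise ValueError(f"labels must have shape ({logits.shape[0]},), got {labels.shape}")
    if np.any((labels < 0) | (labels > 1)):
        raise ValueError("labels must be 0 or 1")

    # Subtract each row's max so exp() never overflows (stable softmax).
    shifted = logits - logits.max(axis=1, keepdims=True)
    # log-softmax = shifted - log(sum(exp(shifted))); the sum is always >= 1.
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Select the log-probability of the true class for each example.
    picked = log_probs[np.arange(logits.shape[0]), labels]
    # Cross-entropy is the negative log-likelihood, averaged over the batch.
    return float(-picked.mean())


if __name__ == "__main__":
    logits = np.array([[0.0, 0.0], [1000.0, -1000.0]])
    labels = np.array([1, 0])
    print(nsp_loss(logits, labels))  # ≈ 0.3466, i.e. log(2) / 2
```

Working in log space matters at the extremes: with logits [1000, -1000] and label 1, exp(-2000) underflows to 0, so log(softmax(...)) would return -inf, while the log-softmax above returns exactly -2000 and the loss stays finite at 2000.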