Implement triplet margin loss with squared L2 distances in PyTorch.
Signature: def triplet_loss(anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, margin: float = 1.0) -> torch.Tensor
Rule: you may NOT call F.triplet_margin_loss, nn.TripletMarginLoss, or torch.cdist; compute the distances yourself.
Allowed primitives: **, .sum(dim=...), .clamp(min=0), .mean(), basic arithmetic.
For each triplet, compute the squared L2 distances anchor-to-positive and anchor-to-negative (sum over the feature dim only — dim=-1), then apply the hinge max(0, d_ap - d_an + margin) and average over the remaining (batch) axes. See the math reference below.
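The steps above can be sketched using only the allowed primitives. This is a minimal sketch that assumes floating-point inputs of identical shape; it is one straightforward way to satisfy the rules, not an official reference solution.

```python
import torch


def triplet_loss(
    anchor: torch.Tensor,
    positive: torch.Tensor,
    negative: torch.Tensor,
    margin: float = 1.0,
) -> torch.Tensor:
    # Squared L2 distance per triplet: reduce over the feature dim (dim=-1) only.
    d_ap = ((anchor - positive) ** 2).sum(dim=-1)
    d_an = ((anchor - negative) ** 2).sum(dim=-1)
    # Hinge max(0, d_ap - d_an + margin), expressed with the allowed clamp(min=0).
    losses = (d_ap - d_an + margin).clamp(min=0)
    # Average over the remaining batch axes (a 0-dim tensor for 1D inputs).
    return losses.mean()
```

For example, anchor [0, 0], positive [1, 1], negative [1, 0] gives d_ap = 2 and d_an = 1, so the loss is max(0, 2 - 1 + 1) = 2. Using clamp(min=0) rather than torch.maximum keeps the hinge within the listed primitives.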
For 1D inputs of shape (D,) (a single triplet), each distance is a 0-dim tensor, and .mean() over that single value returns it unchanged. For 2D batched inputs of shape (B, D), distances are computed per sample (shape (B,)) and the loss is their mean over the batch.
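The shape behavior described above can be checked directly on the distance computation. This sketch uses only the anchor-to-positive distance with arbitrary example values (all ones vs all zeros, chosen for illustration) to show what sum(dim=-1) leaves behind for each input rank.

```python
import torch

# Single triplet, shape (D,): summing over dim=-1 leaves a 0-dim tensor.
a1 = torch.tensor([0.0, 0.0])
p1 = torch.tensor([1.0, 1.0])
d1 = ((a1 - p1) ** 2).sum(dim=-1)
# mean() of a 0-dim tensor is that same value.
single_mean = d1.mean()

# Batch, shape (B, D) with B = 3: one distance per row, shape (B,).
a2 = torch.zeros(3, 2)
p2 = torch.ones(3, 2)
d2 = ((a2 - p2) ** 2).sum(dim=-1)
batch_mean = d2.mean()

print(d1.shape, single_mean.item())  # torch.Size([]) 2.0
print(d2.shape, batch_mean.item())   # torch.Size([3]) 2.0
```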
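Note: the NumPy version of this problem sums over all elements regardless of rank, so for batched inputs the batch dimension is folded into a single distance. The PyTorch version uses the more idiomatic per-sample reduction (sum(dim=-1), then mean()), which matches the reduction='mean' convention of F.triplet_margin_loss. Unlike this problem, F.triplet_margin_loss itself uses unsquared p-norm distances (p=2 by default), so only the reduction matches, not the distance. Expected outputs therefore differ between the NumPy and PyTorch versions for batched inputs.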
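A small batch makes the difference in reduction concrete. The second half below emulates, in PyTorch, the flat "sum over all elements" behavior the note attributes to the NumPy version; it is an illustration based on that description, not the NumPy reference code itself. The input values are arbitrary.

```python
import torch

anchor = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
positive = torch.tensor([[1.0, 1.0], [0.0, 0.0]])
negative = torch.tensor([[1.0, 0.0], [2.0, 0.0]])
margin = 1.0

# PyTorch version: per-sample distances (sum over dim=-1), hinge, then mean.
d_ap = ((anchor - positive) ** 2).sum(dim=-1)  # tensor([2., 0.])
d_an = ((anchor - negative) ** 2).sum(dim=-1)  # tensor([1., 4.])
per_sample = (d_ap - d_an + margin).clamp(min=0).mean()  # (2 + 0) / 2 = 1.0

# Emulated flat reduction: sum over every element, batch dim included.
d_ap_flat = ((anchor - positive) ** 2).sum()  # 2.0
d_an_flat = ((anchor - negative) ** 2).sum()  # 5.0
flat = (d_ap_flat - d_an_flat + margin).clamp(min=0)  # max(0, -2) = 0.0

print(per_sample.item(), flat.item())  # 1.0 0.0
```

The first triplet violates the margin while the second satisfies it by a wide gap; the flat sum lets the second triplet's large d_an cancel the first triplet's violation, which is why the two conventions disagree.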
Math
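Written out from the description above, with a, p, n the anchor, positive, and negative embeddings, B the batch size (B = 1 for a single 1D triplet), and D the feature dimension:

```latex
d_{ap}^{(i)} = \sum_{j=1}^{D} \left( a^{(i)}_j - p^{(i)}_j \right)^2,
\qquad
d_{an}^{(i)} = \sum_{j=1}^{D} \left( a^{(i)}_j - n^{(i)}_j \right)^2

\mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} \max\left( 0,\; d_{ap}^{(i)} - d_{an}^{(i)} + \text{margin} \right)
```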
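import torch

def triplet_loss(anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor,
                 margin: float = 1.0) -> torch.Tensor:
    pass  # TODO: implement using only the allowed primitives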