TorchedUp

244. Triplet Margin Loss (PyTorch)

Easy

Implement triplet margin loss with squared L2 distances in PyTorch.

Signature: def triplet_loss(anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, margin: float = 1.0) -> torch.Tensor

The rule: you may NOT call F.triplet_margin_loss, nn.TripletMarginLoss, or torch.cdist. Implement the distances yourself.

Allowed primitives: **, .sum(dim=...), .clamp(min=0), .mean(), basic arithmetic.
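As a minimal sketch of the distance step, the squared L2 distance can be built from the allowed primitives alone, with no `torch.cdist`. The helper name `squared_l2` is hypothetical (it is not part of the required signature), and the tensors are illustrative.

```python
import torch

def squared_l2(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: element-wise difference, squared, then summed over
    # the feature axis only. (D,) inputs give a 0-d tensor; (B, D) inputs give (B,).
    return ((x - y) ** 2).sum(dim=-1)

a = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
b = torch.tensor([[3.0, 4.0], [1.0, 2.0]])
d = squared_l2(a, b)  # values 25. and 1.: (9 + 16) and (0 + 1)
```

Because the reduction is `dim=-1` rather than a full `.sum()`, the batch axis survives and each row gets its own distance.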

For each triplet, compute the squared L2 distances anchor-to-positive and anchor-to-negative (sum over the feature dim only — dim=-1), then apply the hinge max(0, d_ap - d_an + margin) and average over the remaining (batch) axes. See the math reference below.

For 1D inputs (single triplet), mean over a single value gives that value. For 2D batched inputs (B, D), distances are computed per-sample and the loss is mean over the batch.
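The steps above can be sketched as follows, using only the allowed primitives. This is one possible solution under those constraints, not the site's reference solution (which is not shown here); the example tensors are illustrative.

```python
import torch

def triplet_loss(anchor: torch.Tensor, positive: torch.Tensor,
                 negative: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
    # Squared L2 distances, reduced over the feature dim (dim=-1) only.
    d_ap = ((anchor - positive) ** 2).sum(dim=-1)
    d_an = ((anchor - negative) ** 2).sum(dim=-1)
    # Hinge max(0, d_ap - d_an + margin), written with clamp(min=0).
    losses = (d_ap - d_an + margin).clamp(min=0)
    # Average over the remaining axes: a no-op for a single 1D triplet.
    return losses.mean()

# Single triplet (1D): distances are 0-d tensors, so mean() returns the hinge itself.
a = torch.tensor([0.0, 0.0])
p = torch.tensor([2.0, 0.0])
n = torch.tensor([1.0, 0.0])
single = triplet_loss(a, p, n)  # max(0, 4 - 1 + 1) = 4.0

# Batch (B=2, D=2): per-sample hinges are 0. and 4., averaged to 2.0.
A = torch.zeros(2, 2)
P = torch.tensor([[1.0, 0.0], [2.0, 0.0]])
N = torch.tensor([[2.0, 0.0], [1.0, 0.0]])
batched = triplet_loss(A, P, N)
```

Using `.clamp(min=0)` keeps the hinge element-wise, so the same code handles 0-d (single triplet) and `(B,)` distance tensors without branching on rank.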

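Note: the NumPy version of this problem sums over all elements regardless of rank, so a batch is treated as one flat vector and produces a single hinge. The PyTorch version uses the more idiomatic per-sample reduction (sum(dim=-1), then mean()), which matches the reduction convention of F.triplet_margin_loss (that function itself uses unsquared p-norm distances by default). Expected outputs therefore differ for batched inputs; for a single 1D triplet the two conventions give the same result.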
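To see why expected outputs differ for batched inputs, here is a sketch of both reductions side by side. Both function names are hypothetical, and `flat_sum_loss` is an assumption about the NumPy version based only on the note above (its actual code is not shown here).

```python
import torch

def per_sample_loss(a, p, n, margin=1.0):
    # PyTorch-version convention: reduce over features, hinge per sample, then mean.
    d_ap = ((a - p) ** 2).sum(dim=-1)
    d_an = ((a - n) ** 2).sum(dim=-1)
    return (d_ap - d_an + margin).clamp(min=0).mean()

def flat_sum_loss(a, p, n, margin=1.0):
    # Assumed NumPy-version convention: sum over *all* elements, so the whole
    # batch collapses into one distance pair and a single hinge.
    d_ap = ((a - p) ** 2).sum()
    d_an = ((a - n) ** 2).sum()
    return (d_ap - d_an + margin).clamp(min=0)

a = torch.zeros(2, 2)
p = torch.tensor([[1.0, 0.0], [2.0, 0.0]])
n = torch.tensor([[2.0, 0.0], [1.0, 0.0]])

per_sample = per_sample_loss(a, p, n)  # mean of hinges 0 and 4 -> 2.0
flat = flat_sum_loss(a, p, n)          # max(0, 5 - 5 + 1) -> 1.0
```

In the flat version the well-separated first triplet cancels part of the violating second one, which the per-sample version does not allow.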

Math

$$L = \frac{1}{B}\sum_{i=1}^{B} \max\left(0,\ \lVert a_i - p_i \rVert_2^2 - \lVert a_i - n_i \rVert_2^2 + m\right)$$
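As a worked instance of the formula, take B = 2, D = 2 and m = 1 with the illustrative triplets below. The first triplet already satisfies the margin (its negative is 3 farther than its positive, and 3 is at least m), so only the second contributes.

$$a_1 = (0,0),\ p_1 = (1,0),\ n_1 = (2,0);\qquad a_2 = (0,0),\ p_2 = (2,0),\ n_2 = (1,0);\qquad m = 1$$

$$L = \tfrac{1}{2}\left[\max(0,\ 1 - 4 + 1) + \max(0,\ 4 - 1 + 1)\right] = \tfrac{1}{2}(0 + 4) = 2$$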

Related problems

  • Triplet Margin Loss (Easy, NumPy)


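import torch

def triplet_loss(anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
    # Starter stub: use only **, .sum(dim=...), .clamp(min=0), .mean() and basic arithmetic.
    pass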
