TorchedUp

271. Backprop: BatchNorm (PyTorch)

Hard

Implement BatchNorm1d in training mode as a torch.autograd.Function. The input x has shape (B, F), and each feature is normalized over the batch dimension. The backward pass has the same three-term structure as LayerNorm, except that the reductions run over the batch dimension (dim=0) instead of the feature dimension.

The rule: you may NOT call F.batch_norm or nn.BatchNorm1d.

Forward: mean/var over dim=0; xhat = (x - mean) / sqrt(var + eps); y = xhat * gamma + beta.
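The forward pass described above can be sketched as a plain function. This is a minimal illustration, not a reference solution: the name bn_forward and the choice to return xhat and sigma alongside y are assumptions made for this example, and sigma here denotes sqrt(var + eps).

```python
import torch

def bn_forward(x, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm forward for x of shape (B, F).

    bn_forward is a hypothetical helper name used only in this sketch.
    """
    mean = x.mean(dim=0)                    # per-feature mean, shape (F,)
    var = x.var(dim=0, unbiased=False)      # biased variance, shape (F,)
    sigma = torch.sqrt(var + eps)           # sqrt(var + eps), shape (F,)
    xhat = (x - mean) / sigma               # normalized input, shape (B, F)
    y = xhat * gamma + beta                 # scale and shift per feature
    return y, xhat, sigma

torch.manual_seed(0)
x = torch.randn(8, 3, dtype=torch.float64)
gamma = torch.ones(3, dtype=torch.float64)
beta = torch.zeros(3, dtype=torch.float64)
y, xhat, sigma = bn_forward(x, gamma, beta)
```

With gamma = 1 and beta = 0, each column of y should have a mean close to zero and a biased variance slightly below one (the gap comes from eps).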

Backward (with xhat, sigma, and gamma saved from the forward pass, where sigma = sqrt(var + eps)):

dxhat = grad_output * gamma
grad_x = (1 / sigma) * (dxhat - dxhat.mean(0) - xhat * (dxhat * xhat).mean(0))
grad_gamma = (grad_output * xhat).sum(0)
grad_beta = grad_output.sum(0)
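One way the forward and backward rules above could be wired into a torch.autograd.Function is sketched below. The class name BatchNorm1dFn and the module-level constant EPS are assumptions made for this example; the backward body follows the four formulas line by line.

```python
import torch

EPS = 1e-5  # assumed module-level constant for this sketch

class BatchNorm1dFn(torch.autograd.Function):
    """Hypothetical name; training-mode BatchNorm over dim=0 for x of shape (B, F)."""

    @staticmethod
    def forward(ctx, x, gamma, beta):
        mean = x.mean(dim=0)
        var = x.var(dim=0, unbiased=False)   # biased variance
        sigma = torch.sqrt(var + EPS)
        xhat = (x - mean) / sigma
        ctx.save_for_backward(xhat, sigma, gamma)
        return xhat * gamma + beta

    @staticmethod
    def backward(ctx, grad_output):
        xhat, sigma, gamma = ctx.saved_tensors
        dxhat = grad_output * gamma
        # Three terms: direct path, path through the mean, path through the variance.
        grad_x = (1.0 / sigma) * (
            dxhat - dxhat.mean(0) - xhat * (dxhat * xhat).mean(0)
        )
        grad_gamma = (grad_output * xhat).sum(0)
        grad_beta = grad_output.sum(0)
        return grad_x, grad_gamma, grad_beta

torch.manual_seed(0)
x = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(3, dtype=torch.float64, requires_grad=True)
beta = torch.randn(3, dtype=torch.float64, requires_grad=True)

# gradcheck compares the analytic backward with finite differences (float64 inputs).
ok = torch.autograd.gradcheck(BatchNorm1dFn.apply, (x, gamma, beta))
```

Double precision is used for the check because finite differences in float32 are usually too noisy for gradcheck's default tolerances.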

The driver bn_run(mode, x, gamma, beta) dispatches 'forward' | 'grad_x' | 'grad_gamma' | 'grad_beta' | 'gradcheck'. Use eps = 1e-5 and biased variance (unbiased=False).
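A possible shape for the driver is sketched below. The problem statement does not say which upstream gradient the 'grad_*' modes use or what 'gradcheck' returns, so this sketch assumes an optional grad_output keyword (defaulting to all ones) and a boolean result from torch.autograd.gradcheck. bn_apply is a stand-in built from differentiable torch ops so the example runs on its own; an actual solution would call the custom Function's apply there instead.

```python
import torch

EPS = 1e-5

def bn_apply(x, gamma, beta):
    # Stand-in for the custom torch.autograd.Function's .apply (assumption for
    # this sketch); it computes the same forward with ordinary differentiable ops.
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    return (x - mean) / torch.sqrt(var + EPS) * gamma + beta

def bn_run(mode, x, gamma, beta, grad_output=None):
    # Work on float64 leaf copies so that gradcheck is numerically meaningful.
    x, gamma, beta = (
        torch.as_tensor(t, dtype=torch.float64).detach().clone().requires_grad_(True)
        for t in (x, gamma, beta)
    )
    if mode == "gradcheck":
        return torch.autograd.gradcheck(bn_apply, (x, gamma, beta))
    y = bn_apply(x, gamma, beta)
    if mode == "forward":
        return y.detach()
    if mode not in ("grad_x", "grad_gamma", "grad_beta"):
        raise ValueError(f"unknown mode: {mode!r}")
    # Assumed upstream gradient: all ones unless the caller supplies one.
    if grad_output is None:
        grad_output = torch.ones_like(y)
    else:
        grad_output = torch.as_tensor(grad_output, dtype=torch.float64)
    y.backward(grad_output)
    grads = {"grad_x": x.grad, "grad_gamma": gamma.grad, "grad_beta": beta.grad}
    return grads[mode]

torch.manual_seed(0)
x = torch.randn(6, 4)
gamma = torch.ones(4)
beta = torch.zeros(4)
y = bn_run("forward", x, gamma, beta)
grad_beta = bn_run("grad_beta", x, gamma, beta)
```

Note that with an all-ones upstream gradient, grad_x and grad_gamma come out numerically close to zero, because the mean-subtraction terms cancel the constant gradient and xhat sums to zero over the batch. A non-constant grad_output gives a more informative check.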

Math

$$y_{b,f} = \gamma_f \, \frac{x_{b,f} - \mu_f}{\sqrt{\sigma_f^2 + \epsilon}} + \beta_f, \qquad \mu_f,\ \sigma_f^2 \text{ computed over the batch}$$
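For reference, the backward rule from the problem statement can be written in the same index notation. The symbols below are introduced only for this restatement and are not part of the original page: $g_{b,f}$ is the upstream gradient (grad_output), $s_f = \sqrt{\sigma_f^2 + \epsilon}$ is the saved sigma, $\hat{x}_{b,f} = (x_{b,f} - \mu_f)/s_f$, and $\delta_{b,f} = g_{b,f}\,\gamma_f$ is dxhat.

$$\frac{\partial L}{\partial x_{b,f}} = \frac{1}{s_f}\left(\delta_{b,f} - \frac{1}{B}\sum_{b'=1}^{B}\delta_{b',f} - \hat{x}_{b,f}\,\frac{1}{B}\sum_{b'=1}^{B}\delta_{b',f}\,\hat{x}_{b',f}\right)$$

$$\frac{\partial L}{\partial \gamma_f} = \sum_{b=1}^{B} g_{b,f}\,\hat{x}_{b,f}, \qquad \frac{\partial L}{\partial \beta_f} = \sum_{b=1}^{B} g_{b,f}$$

The three terms inside the parentheses correspond to the direct path through $\hat{x}$, the path through $\mu_f$, and the path through $\sigma_f^2$.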

Related problems

  • Backprop: BatchNorm (train mode) (Hard, NumPy)

NumPy

import numpy as np

# mode is one of 'forward', 'grad_x', 'grad_gamma', 'grad_beta', 'gradcheck'
def bn_run(mode, x, gamma, beta):
    pass
