TorchedUp

207. Backprop: BatchNorm (train mode)

Hard

Hand-derive the gradient of L = sum(BatchNorm(x)) w.r.t. the input x in training mode (using batch statistics, no running mean/var).

Forward (per feature column, along the batch axis):

  • mu = mean(x, axis=0), var = mean((x - mu)^2, axis=0), std = sqrt(var + eps)
  • x_hat = (x - mu) / std
  • y = gamma * x_hat + beta
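The three forward steps above map almost line for line onto NumPy. The following is a minimal sketch, not a reference solution: it assumes x is a float array of shape (N, F) and that gamma and beta have shape (F,), so they broadcast across the batch axis.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm forward pass using batch statistics only."""
    mu = x.mean(axis=0)                   # per-feature mean, shape (F,)
    var = ((x - mu) ** 2).mean(axis=0)    # biased variance (divides by N)
    std = np.sqrt(var + eps)              # eps guards against division by zero
    x_hat = (x - mu) / std                # normalized input, shape (N, F)
    return gamma * x_hat + beta           # scale and shift, shape (N, F)

# Example call on a small (N=4, F=3) batch.
x = np.arange(12, dtype=float).reshape(4, 3)
y = batchnorm_forward(x, gamma=np.ones(3), beta=np.zeros(3))
```

With gamma set to ones and beta to zeros, each column of y should have mean close to 0 and standard deviation close to 1, which makes a quick sanity check.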

Implement:

  • batchnorm_forward(x, gamma, beta, eps=1e-5) -> y of shape (N, F)
  • batchnorm_backward(x, gamma, beta, eps=1e-5) -> dL/dx of shape (N, F)

The math is identical to LayerNorm but transposed: the reductions now run along the batch dimension instead of the feature dimension. With L = sum(y), dL/dx_hat = gamma (broadcast across the batch), so

dL/dx = (1/std) * (dL/dx_hat - mean_batch(dL/dx_hat) - x_hat * mean_batch(dL/dx_hat * x_hat))

Note: gamma is constant along the batch axis, so mean_batch(dL/dx_hat) is just gamma itself and the first two terms cancel. Verify what is left.
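One way to check a hand-derived gradient is to compare it against central finite differences. The sketch below implements the general three-term formula with dL/dx_hat = gamma and does that comparison. The helper numerical_grad and its step size h are illustrative assumptions rather than part of the problem, and the shapes are assumed to be (N, F) for x and (F,) for gamma and beta.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=0)
    std = np.sqrt(((x - mu) ** 2).mean(axis=0) + eps)
    return gamma * (x - mu) / std + beta

def batchnorm_backward(x, gamma, beta, eps=1e-5):
    """dL/dx for L = sum(BatchNorm(x)) in training mode."""
    mu = x.mean(axis=0)
    std = np.sqrt(((x - mu) ** 2).mean(axis=0) + eps)
    x_hat = (x - mu) / std
    # For L = sum(y), dL/dx_hat is gamma broadcast across the batch.
    dx_hat = np.broadcast_to(gamma, x.shape)
    mean_dx_hat = dx_hat.mean(axis=0)                # mean over the batch axis
    mean_dx_hat_x_hat = (dx_hat * x_hat).mean(axis=0)
    return (dx_hat - mean_dx_hat - x_hat * mean_dx_hat_x_hat) / std

def numerical_grad(x, gamma, beta, eps=1e-5, h=1e-6):
    """Hypothetical helper: central differences of sum(forward(x))."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += h
        x_minus[idx] -= h
        f_plus = batchnorm_forward(x_plus, gamma, beta, eps).sum()
        f_minus = batchnorm_forward(x_minus, gamma, beta, eps).sum()
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 3))
gamma = np.array([1.0, 2.0, 0.5])
beta = np.array([0.1, -0.2, 0.3])
analytic = batchnorm_backward(x, gamma, beta)
numeric = numerical_grad(x, gamma, beta)
print(np.abs(analytic - numeric).max())
```

The analytic version keeps all three terms rather than hard-coding the simplification, so the same function body can be adapted to other upstream gradients by swapping out dx_hat.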

Math

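$$
\frac{\partial L}{\partial x_{nj}} = \frac{1}{\sigma_j}\left(\gamma_j - \overline{\gamma_j}^{(n)} - \hat{x}_{nj}\cdot\overline{\gamma_j\,\hat{x}_{nj}}^{(n)}\right)
$$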
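As a worked check of the simplification, the two batch means can be evaluated explicitly. This sketch writes N for the batch size, \mu_j and \sigma_j for mu and std of feature j, and an overline with superscript (n) for the mean over the batch index n; it uses only the forward definitions above.

```latex
\begin{align*}
\overline{\gamma_j}^{(n)}
  &= \frac{1}{N}\sum_{n=1}^{N}\gamma_j = \gamma_j, \\
\overline{\gamma_j\,\hat{x}_{nj}}^{(n)}
  &= \frac{\gamma_j}{N\sigma_j}\sum_{n=1}^{N}\left(x_{nj}-\mu_j\right) = 0, \\
\frac{\partial L}{\partial x_{nj}}
  &= \frac{1}{\sigma_j}\left(\gamma_j - \gamma_j - \hat{x}_{nj}\cdot 0\right) = 0 .
\end{align*}
```

This agrees with a direct look at the loss: summing y over the batch gives L = \sum_j \left(\gamma_j \sum_n \hat{x}_{nj} + N\beta_j\right) = N\sum_j \beta_j, which does not depend on x. In floating point the computed gradient will be tiny rather than exactly zero.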

Related problems

  • Backprop: BatchNorm (PyTorch) (Hard, PyTorch)

NumPy

import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    pass
