TorchedUp

198. BatchNorm: Training / Inference Mode

Medium

Your model trained beautifully — 95% accuracy, smooth loss curve, the works. You ship it. The inference pipeline reports 52% accuracy on the same data the model just memorized.

Welcome to the most-stepped-on rake in machine learning: BatchNorm at inference time.

During training, BN normalizes each feature using the current batch's mean and variance. That works because gradients flow through the stats and the model adapts. At inference, you might pass one sample at a time — its "batch statistics" are nonsense (variance = 0 across a single sample). You need the running statistics accumulated during training.
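To make the single-sample failure concrete, here is a minimal sketch. The input and the running statistics are made-up values chosen for illustration, not numbers from the problem.

```python
import numpy as np

eps = 1e-5
x = np.array([[3.0, -1.0, 7.5]])  # one sample, shape (1, 3)

# Batch statistics of a single sample: the mean is the sample itself
# and the variance is exactly zero.
batch_mean = x.mean(axis=0)
batch_var = x.var(axis=0)
x_hat_batch = (x - batch_mean) / np.sqrt(batch_var + eps)

# Hypothetical running statistics, as if accumulated during training.
running_mean = np.array([2.0, 0.0, 5.0])
running_var = np.array([4.0, 1.0, 25.0])
x_hat_running = (x - running_mean) / np.sqrt(running_var + eps)

print(x_hat_batch)    # every entry is 0: the input has been erased
print(x_hat_running)  # roughly [[0.5, -1.0, 0.5]]
```

With batch statistics, every sample normalizes to zero, so the layer outputs beta regardless of the input. With running statistics, the sample keeps its position relative to the training distribution.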

Your job: implement the forward pass for both modes plus the running-stats EMA update.

Signature:

def batchnorm_forward(x, gamma, beta, running_mean, running_var,
                     momentum=0.1, eps=1e-5, is_training=True):
    """
    Returns: (output, new_running_mean, new_running_var)
    """

Inputs:

  • x: (N, D) batch of D-dim vectors
  • gamma, beta: (D,) learnable affine
  • running_mean, running_var: (D,) accumulated stats from prior train steps
  • momentum: EMA weight on the new batch stats
  • eps: numerical-stability constant
  • is_training: True → use batch stats, update running stats. False → use running stats as-is.

Behavior:

| mode | normalize using | update running stats? |
|---|---|---|
| train | batch mean/var | yes (new = (1-m)*old + m*batch) |
| inference | running mean/var | no (return them unchanged) |

Returns: a tuple (output, new_running_mean, new_running_var). In inference mode, the second and third elements are the input running stats unchanged.
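One possible sketch of the function under the spec above. It is not the reference solution. It assumes the running variance is updated with the same biased batch variance used for normalization, which the behavior table suggests but does not state outright.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, running_mean, running_var,
                      momentum=0.1, eps=1e-5, is_training=True):
    """Returns: (output, new_running_mean, new_running_var)"""
    if is_training:
        # Per-feature statistics over the batch axis (biased variance, ddof=0).
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        # EMA update: momentum weights the NEW batch statistics.
        new_running_mean = (1 - momentum) * running_mean + momentum * mean
        new_running_var = (1 - momentum) * running_var + momentum * var
    else:
        # Inference: normalize with the stored statistics and leave them untouched.
        mean, var = running_mean, running_var
        new_running_mean, new_running_var = running_mean, running_var

    x_hat = (x - mean) / np.sqrt(var + eps)
    out = gamma * x_hat + beta
    return out, new_running_mean, new_running_var


# Usage with illustrative values (N=2, D=2).
x = np.array([[1.0, 2.0], [3.0, 6.0]])
gamma, beta = np.ones(2), np.zeros(2)
rm, rv = np.zeros(2), np.ones(2)

out_train, rm_new, rv_new = batchnorm_forward(x, gamma, beta, rm, rv)
out_eval, rm_same, rv_same = batchnorm_forward(x, gamma, beta, rm, rv,
                                               is_training=False)
```

In the usage lines, the batch mean is [2, 4] and the biased batch variance is [1, 4], so the training call moves the running mean to [0.2, 0.4] and the running variance to [1.0, 1.3]. The inference call returns the running statistics it was given.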

Common gotchas your tests will catch:

  • Using batch stats in inference mode (the original sin)
  • Forgetting eps so a degenerate-variance feature explodes
  • EMA direction reversed (m*old + (1-m)*new instead of (1-m)*old + m*new)
  • Updating running stats during inference mode
  • Using the unbiased variance instead of the biased one. The BN convention is biased (divide by N, not N-1); np.var defaults to biased (ddof=0), so calling it without arguments is fine.
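Three of the gotchas above can be reproduced in a few lines. This is an illustrative sketch with made-up numbers; the variable names are not part of the problem.

```python
import numpy as np

# The second feature is constant across the batch, so its variance is zero.
x = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]])
mean, var = x.mean(axis=0), x.var(axis=0)

# Missing eps: the constant feature becomes 0/0, which is nan.
with np.errstate(divide="ignore", invalid="ignore"):
    no_eps = (x - mean) / np.sqrt(var)
with_eps = (x - mean) / np.sqrt(var + 1e-5)  # constant feature maps to 0.0

# EMA direction: momentum should weight the new batch statistic.
m, old, new = 0.1, 0.0, 10.0
correct = (1 - m) * old + m * new     # 1.0, a small step toward the batch value
reversed_ = m * old + (1 - m) * new   # 9.0, history almost overwritten

# Biased vs unbiased variance of the first feature (values 1, 2, 3).
biased = x[:, 0].var()           # divides by N   -> 2/3
unbiased = x[:, 0].var(ddof=1)   # divides by N-1 -> 1.0
```

The reversed EMA still runs without error, which is why it tends to surface only as a wrong running statistic after the first update.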

Math

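$$
\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}}, \qquad y = \gamma \hat{x} + \beta, \qquad \mu_{\text{run}} \leftarrow (1 - m)\,\mu_{\text{run}} + m\,\mu_B
$$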
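For completeness, the batch statistics and the running-variance update can be written out as well. The variance update is inferred from the behavior table (new = (1-m)*old + m*batch) and the biased-variance convention in the gotchas; the problem does not state it as a formula, so treat it as an assumption.

$$
\mu_B = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad \sigma_B^2 = \frac{1}{N}\sum_{i=1}^{N} (x_i - \mu_B)^2, \qquad \sigma^2_{\text{run}} \leftarrow (1 - m)\,\sigma^2_{\text{run}} + m\,\sigma_B^2
$$

As a quick check with m = 0.1: a running mean of 0 and a batch mean of 2 give 0.9 · 0 + 0.1 · 2 = 0.2.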

NumPy

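import numpy as np

def batchnorm_forward(x, gamma, beta, running_mean, running_var,
                     momentum=0.1, eps=1e-5, is_training=True):
    # Your implementation goes here.
    pass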
