TorchedUp

206. Backprop: LayerNorm

Hard

Hand-derive the gradient of L = sum(LayerNorm(x)) w.r.t. the input x.
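Before hand-deriving anything, it helps to have a numerical reference to check the result against. The sketch below estimates dL/dx by central differences, perturbing one entry of x at a time. The helper name `numerical_grad` and the step size `h=1e-6` are illustrative choices, not part of the problem statement. The sketch also assumes float64 inputs and a gamma of shape (D,).

```python
import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # Per-row LayerNorm over the last axis
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def numerical_grad(x, gamma, beta, eps=1e-5, h=1e-6):
    """Hypothetical helper: central-difference estimate of dL/dx for L = sum(LayerNorm(x))."""
    x = np.asarray(x, dtype=np.float64)
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):          # perturb one entry at a time
        x_plus, x_minus = x.copy(), x.copy()
        x_plus[idx] += h
        x_minus[idx] -= h
        L_plus = layernorm_forward(x_plus, gamma, beta, eps).sum()
        L_minus = layernorm_forward(x_minus, gamma, beta, eps).sum()
        grad[idx] = (L_plus - L_minus) / (2.0 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4))
gamma = rng.normal(size=4)
beta = rng.normal(size=4)
g = numerical_grad(x, gamma, beta)
print(g.shape)  # (2, 4)
```

This costs two forward passes per input entry, so it is only practical for small test inputs.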

Forward (per row, along the last axis):

  • mu = mean(x), var = mean((x - mu)^2), std = sqrt(var + eps)
  • x_hat = (x - mu) / std
  • y = gamma * x_hat + beta
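A minimal NumPy sketch of the forward pass above. It assumes x has shape (..., D) and gamma, beta have shape (D,) so they broadcast across rows; the problem does not pin down these shapes.

```python
import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis: y = gamma * x_hat + beta."""
    mu = x.mean(axis=-1, keepdims=True)                 # per-row mean
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)  # biased variance (divide by D)
    std = np.sqrt(var + eps)                            # eps keeps std away from zero
    x_hat = (x - mu) / std                              # normalized row
    return gamma * x_hat + beta                         # gamma, beta broadcast over rows

x = np.array([[1.0, 2.0, 3.0, 4.0],
              [2.0, 2.0, 2.0, 6.0]])
gamma = np.ones(4)
beta = np.zeros(4)
y = layernorm_forward(x, gamma, beta)
print(y.shape)  # (2, 4)
```

Using keepdims=True keeps mu and std as (..., 1) columns, so they broadcast back against x without any reshaping.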

Implement:

  • layernorm_forward(x, gamma, beta, eps=1e-5) -> y of the same shape as x
  • layernorm_backward(x, gamma, beta, eps=1e-5) -> dL/dx of the same shape

The gradient has three branches: a direct path through x_hat, a path through the mean mu, and a path through the variance var. They combine into the canonical formula
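One way to see the three branches is to compute each one separately and add them, as in this sketch. The function name `layernorm_backward_branches` is hypothetical. The sketch assumes gamma has shape (D,) and uses dL/dx_hat = gamma, which holds because L = sum(y) gives dL/dy = 1.

```python
import numpy as np

def layernorm_backward_branches(x, gamma, beta, eps=1e-5):
    """dL/dx for L = sum(LayerNorm(x)), written as the sum of three chain-rule branches.

    beta is accepted only to match the problem's signature; it does not affect dL/dx.
    """
    D = x.shape[-1]
    mu = x.mean(axis=-1, keepdims=True)
    xc = x - mu                                   # centered input, x - mu
    var = (xc ** 2).mean(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(var + eps)            # 1 / std
    # dL/dy = 1, so dL/dx_hat = gamma, broadcast to every row
    dx_hat = np.broadcast_to(gamma, x.shape)

    # Branch 1 (direct): x_j -> x_hat_j with mu and var held fixed
    dx_direct = dx_hat * inv_std

    # Branch 2 (variance): dL/dvar, then dvar/dx_j = 2 * (x_j - mu) / D
    dvar = -0.5 * inv_std ** 3 * np.sum(dx_hat * xc, axis=-1, keepdims=True)
    dx_var = dvar * 2.0 * xc / D

    # Branch 3 (mean): dL/dmu, then dmu/dx_j = 1 / D.
    # var also depends on mu, but that term is dvar * mean(-2 * xc) = 0, so it is dropped.
    dmu = -inv_std * np.sum(dx_hat, axis=-1, keepdims=True)
    dx_mu = dmu / D

    return dx_direct + dx_var + dx_mu

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))
gamma = rng.normal(size=5)
dx = layernorm_backward_branches(x, gamma, np.zeros(5))
```

Branch 2 simplifies to -(1/std) * x_hat * mean(gamma * x_hat) and branch 3 to -(1/std) * mean(gamma), which is where the two subtracted terms of the fused formula come from.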

dL/dx = (1/std) * (dL/dx_hat - mean(dL/dx_hat) - x_hat * mean(dL/dx_hat * x_hat))

where every mean is taken per row over the last axis, and dL/dx_hat = gamma, broadcast to every row (because L = sum(y), so dL/dy = 1, and y = gamma * x_hat + beta).
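A minimal sketch of the fused formula using the problem's `layernorm_backward` signature, followed by a central-difference check. The step size `h` and the random test input are illustrative assumptions, and gamma is assumed to have shape (D,).

```python
import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def layernorm_backward(x, gamma, beta, eps=1e-5):
    """Fused dL/dx for L = sum(layernorm_forward(x, gamma, beta, eps)).

    beta only shifts y, so it does not appear in the gradient.
    """
    mu = x.mean(axis=-1, keepdims=True)
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
    std = np.sqrt(var + eps)
    x_hat = (x - mu) / std
    dx_hat = np.broadcast_to(gamma, x.shape)          # dL/dx_hat = gamma (dL/dy = 1)
    mean_dx_hat = dx_hat.mean(axis=-1, keepdims=True)
    mean_dx_hat_xhat = (dx_hat * x_hat).mean(axis=-1, keepdims=True)
    return (dx_hat - mean_dx_hat - x_hat * mean_dx_hat_xhat) / std

# Check against central differences on a small random input
rng = np.random.default_rng(0)
x = rng.normal(size=(3, 5))
gamma, beta = rng.normal(size=5), rng.normal(size=5)
dx = layernorm_backward(x, gamma, beta)

h = 1e-6
dx_num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    e = np.zeros_like(x)
    e[idx] = h
    dx_num[idx] = (layernorm_forward(x + e, gamma, beta).sum()
                   - layernorm_forward(x - e, gamma, beta).sum()) / (2 * h)
print(np.max(np.abs(dx - dx_num)))  # should be far below 1e-6
```

Computing the two per-row means once keeps the backward pass to a few vectorized reductions instead of building the D × D Jacobian of each row.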

Math

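$$\frac{\partial L}{\partial x_j} = \frac{1}{\sigma}\left(\gamma_j - \bar{\gamma} - \hat{x}_j \cdot \overline{\gamma \hat{x}}\right)$$

where $\sigma = \mathrm{std} = \sqrt{\mathrm{var} + \epsilon}$, $\bar{\gamma}$ is the mean of $\gamma$ over the last axis, and $\overline{\gamma \hat{x}}$ is the per-row mean of $\gamma \odot \hat{x}$.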
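Two consequences of this formula make useful sanity checks for an implementation. Both rest on the fact that x̂ has zero mean in every row, i.e. the sum of x̂_j over j is 0. Here σ = sqrt(var + eps), D is the number of features in a row, and bars denote per-row means over the last axis.

```latex
% (1) Each row of the gradient sums to zero:
\sum_{j=1}^{D} \frac{\partial L}{\partial x_j}
  = \frac{1}{\sigma}\Big( \sum_{j} \gamma_j - D\,\bar{\gamma} - \overline{\gamma\hat{x}} \sum_{j} \hat{x}_j \Big)
  = \frac{1}{\sigma}\big( D\,\bar{\gamma} - D\,\bar{\gamma} - 0 \big) = 0

% (2) A constant gamma (\gamma_j = c for all j) gives a zero gradient:
\gamma_j - \bar{\gamma} = c - c = 0, \qquad
\overline{\gamma\hat{x}} = c\,\overline{\hat{x}} = 0
\;\Longrightarrow\; \frac{\partial L}{\partial x_j} = 0
```

Intuitively, adding the same constant to every entry of a row leaves LayerNorm unchanged, so the gradient has no component along that direction. With a constant gamma, each row contributes c times the sum of x̂_j plus the sum of beta_j, which is just the sum of beta_j and does not depend on x at all.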

Related problems

  • Backprop: LayerNorm (PyTorch) (Hard, PyTorch)


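NumPy starter code:

import numpy as np

def layernorm_forward(x, gamma, beta, eps=1e-5):
    # Return y = gamma * x_hat + beta, same shape as x
    pass

def layernorm_backward(x, gamma, beta, eps=1e-5):
    # Return dL/dx for L = sum(layernorm_forward(x, gamma, beta, eps)), same shape as x
    pass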
