TorchedUp

270. Backprop: LayerNorm (PyTorch)

Hard

Implement LayerNorm (along the last dim) as a torch.autograd.Function with a hand-derived, three-branch backward pass.

The rule: you may NOT call F.layer_norm or nn.LayerNorm.

Forward: y = gamma * (x - mean) / sqrt(var + eps) + beta, with mean and var computed over the last dim (var is the biased variance, dividing by N rather than N - 1).
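As a minimal sketch of the forward formula above, the function below computes the normalized output with plain tensor ops. The name `layernorm_forward` and the choice to return `xhat` and `inv_sigma` alongside `y` are assumptions made for illustration; they are not part of the problem's required interface.

```python
import torch


def layernorm_forward(x, gamma, beta, eps=1e-5):
    """Normalize over the last dim, then scale by gamma and shift by beta.

    Hypothetical helper: returns xhat and 1/sigma as well, since a
    hand-written backward pass would need them.
    """
    mean = x.mean(dim=-1, keepdim=True)
    # Biased variance (divide by N), the usual LayerNorm convention.
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    inv_sigma = torch.rsqrt(var + eps)  # 1 / sqrt(var + eps)
    xhat = (x - mean) * inv_sigma
    y = gamma * xhat + beta
    return y, xhat, inv_sigma


torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float64)
gamma = torch.randn(8, dtype=torch.float64)
beta = torch.randn(8, dtype=torch.float64)
y, xhat, inv_sigma = layernorm_forward(x, gamma, beta)
print(y.shape)
```

Keeping `keepdim=True` on the reductions lets `mean` and `inv_sigma` broadcast back against `x` without reshaping.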

Backward has three contributions to dL/dx: through x directly (via the (x - mean) / sigma factor), through mean, and through var. The clean form (with sigma = sqrt(var + eps) and xhat = (x - mean) / sigma):

dL/dxhat = grad_output * gamma
dL/dx = (1 / sigma) * (dL/dxhat - dL/dxhat.mean(-1, keepdim=True) - xhat * (dL/dxhat * xhat).mean(-1, keepdim=True))
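One way to see where the clean form comes from is to write out the three branches explicitly and collapse them. The derivation below is a sketch that assumes $N$ is the size of the last dim, $\sigma = \sqrt{\sigma^2 + \epsilon}$ with $\sigma^2$ the biased variance, and $g_i = \partial L / \partial \hat{x}_i$; the symbols $g_i$ and $N$ are introduced here for illustration.

```latex
\begin{align*}
\frac{\partial L}{\partial \sigma^2}
  &= \sum_j g_j \,(x_j - \mu)\cdot\left(-\tfrac{1}{2}\right)\sigma^{-3}
   = -\frac{1}{2\sigma^2}\sum_j g_j \hat{x}_j \\
\frac{\partial L}{\partial \mu}
  &= -\frac{1}{\sigma}\sum_j g_j
     \;+\; \frac{\partial L}{\partial \sigma^2}\cdot\left(-\frac{2}{N}\right)\sum_j (x_j - \mu)
   = -\frac{1}{\sigma}\sum_j g_j
     \qquad \text{since } \textstyle\sum_j (x_j - \mu) = 0 \\
\frac{\partial L}{\partial x_i}
  &= \underbrace{\frac{g_i}{\sigma}}_{\text{direct}}
   + \underbrace{\frac{1}{N}\,\frac{\partial L}{\partial \mu}}_{\text{through mean}}
   + \underbrace{\frac{2\,(x_i - \mu)}{N}\,\frac{\partial L}{\partial \sigma^2}}_{\text{through var}} \\
  &= \frac{1}{\sigma}\left( g_i \;-\; \frac{1}{N}\sum_j g_j \;-\; \hat{x}_i \cdot \frac{1}{N}\sum_j g_j \hat{x}_j \right)
\end{align*}
```

The last line matches the clean form term by term: the two sums divided by $N$ are the means over the last dim.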

dL/dgamma = sum_batch(grad_output * xhat), dL/dbeta = sum_batch(grad_output), where sum_batch sums over every dim except the last.
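The formulas above can be sketched as a standalone function and compared against PyTorch's own autograd on the same forward computation. The helper name `layernorm_backward` and its argument order are assumptions for illustration, as is the choice to flatten all leading dims before summing.

```python
import torch


def layernorm_backward(grad_output, xhat, inv_sigma, gamma):
    """Hand-derived LayerNorm gradients (hypothetical helper).

    inv_sigma is 1 / sqrt(var + eps), shaped (..., 1) for broadcasting.
    """
    n = xhat.shape[-1]
    # Parameter grads: sum over every dim except the last.
    dgamma = (grad_output * xhat).reshape(-1, n).sum(dim=0)
    dbeta = grad_output.reshape(-1, n).sum(dim=0)
    # Input grad: direct term minus the mean and variance corrections.
    dxhat = grad_output * gamma
    dx = inv_sigma * (
        dxhat
        - dxhat.mean(dim=-1, keepdim=True)
        - xhat * (dxhat * xhat).mean(dim=-1, keepdim=True)
    )
    return dx, dgamma, dbeta


torch.manual_seed(0)
eps = 1e-5
x = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(8, dtype=torch.float64, requires_grad=True)
beta = torch.randn(8, dtype=torch.float64, requires_grad=True)

# Forward pass written with differentiable ops so autograd can serve as a reference.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
inv_sigma = torch.rsqrt(var + eps)
xhat = (x - mean) * inv_sigma
y = gamma * xhat + beta

grad_output = torch.randn_like(y)
ref_dx, ref_dgamma, ref_dbeta = torch.autograd.grad(y, (x, gamma, beta), grad_output)
dx, dgamma, dbeta = layernorm_backward(
    grad_output, xhat.detach(), inv_sigma.detach(), gamma.detach()
)
print(torch.allclose(dx, ref_dx), torch.allclose(dgamma, ref_dgamma))
```

A useful sanity property of the clean form is that each row of `dx` sums to zero over the last dim, because shifting every element of a row by the same constant leaves the output unchanged.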

The driver ln_run(mode, x, gamma, beta) dispatches 'forward' | 'grad_x' | 'grad_gamma' | 'grad_beta' | 'gradcheck'. Use eps = 1e-5.
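A rough sketch of how the pieces might fit together is shown below. The problem statement does not say which loss the gradient modes differentiate, what each mode returns, or what input types the driver receives, so this sketch assumes the loss is `y.sum()`, that inputs are tensors or array-likes converted to float64, and that 'gradcheck' returns the boolean from `torch.autograd.gradcheck`. The class name `LayerNormFn` is also hypothetical.

```python
import torch

EPS = 1e-5


class LayerNormFn(torch.autograd.Function):
    """LayerNorm over the last dim with a hand-written backward (sketch)."""

    @staticmethod
    def forward(ctx, x, gamma, beta):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased variance
        inv_sigma = torch.rsqrt(var + EPS)
        xhat = (x - mean) * inv_sigma
        ctx.save_for_backward(xhat, inv_sigma, gamma)
        return gamma * xhat + beta

    @staticmethod
    def backward(ctx, grad_output):
        xhat, inv_sigma, gamma = ctx.saved_tensors
        n = xhat.shape[-1]
        dxhat = grad_output * gamma
        dx = inv_sigma * (
            dxhat
            - dxhat.mean(dim=-1, keepdim=True)
            - xhat * (dxhat * xhat).mean(dim=-1, keepdim=True)
        )
        dgamma = (grad_output * xhat).reshape(-1, n).sum(dim=0)
        dbeta = grad_output.reshape(-1, n).sum(dim=0)
        return dx, dgamma, dbeta


def ln_run(mode, x, gamma, beta):
    modes = ("forward", "grad_x", "grad_gamma", "grad_beta", "gradcheck")
    if mode not in modes:
        raise ValueError(f"unknown mode: {mode!r}")
    # Assumption: work on float64 copies so gradcheck's finite differences are accurate.
    x, gamma, beta = (
        torch.as_tensor(t, dtype=torch.float64).detach().clone().requires_grad_(True)
        for t in (x, gamma, beta)
    )
    if mode == "forward":
        return LayerNormFn.apply(x, gamma, beta).detach()
    if mode == "gradcheck":
        return torch.autograd.gradcheck(LayerNormFn.apply, (x, gamma, beta))
    # Assumption: the scalar loss is y.sum(); the real driver may use another loss.
    LayerNormFn.apply(x, gamma, beta).sum().backward()
    return {"grad_x": x.grad, "grad_gamma": gamma.grad, "grad_beta": beta.grad}[mode]


torch.manual_seed(0)
x = torch.randn(4, 8)
gamma = torch.randn(8)
beta = torch.randn(8)
print(ln_run("gradcheck", x, gamma, beta))
print(ln_run("grad_beta", x, gamma, beta))
```

Saving `xhat` and `inv_sigma` in the forward pass avoids recomputing the statistics in backward, at the cost of holding one extra tensor the size of `x`.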

Math

$$y = \gamma \, \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta, \qquad \mu,\ \sigma^2 \text{ computed over the last dim}$$

Related problems

  • Backprop: LayerNorm (Hard, NumPy)

NumPy

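import numpy as np
import torch

def ln_run(mode, x, gamma, beta):
    # mode is one of 'forward' | 'grad_x' | 'grad_gamma' | 'grad_beta' | 'gradcheck'
    pass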
