Implement LayerNorm (along the last dim) as a torch.autograd.Function with a hand-derived three-branch backward.
The rule: you may NOT call F.layer_norm or nn.LayerNorm.
Forward: y = gamma * (x - mean) / sqrt(var + eps) + beta, with mean and var (biased, i.e. divided by N rather than N - 1) computed over the last dim.
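As a reference point for the forward formula, here is a minimal NumPy sketch. The helper name layer_norm_forward is hypothetical, and the biased variance (ddof=0) is an assumption chosen to match how nn.LayerNorm behaves.

```python
import numpy as np

def layer_norm_forward(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: LayerNorm over the last axis of x."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)      # biased variance (ddof=0), assumed
    xhat = (x - mean) / np.sqrt(var + eps)   # normalized input
    return gamma * xhat + beta               # per-feature scale and shift

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
# With gamma = 1 and beta = 0, each row should come out with roughly
# zero mean and unit variance.
y = layer_norm_forward(x, np.ones(8), np.zeros(8))
```

With gamma set to ones and beta to zeros, every row of y should have mean close to 0 and variance close to 1 (slightly below 1 because of eps).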
Backward: x reaches the output through three paths, so dL/dx has three contributions: through x directly (via the (x - mean) / sigma factor), through mean, and through var. Summing the three branches gives a compact form, with sigma = sqrt(var + eps) and xhat = (x - mean) / sigma:
dL/dxhat = grad_output * gamma
dL/dx = (1 / sigma) * (dL/dxhat - dL/dxhat.mean(-1, keepdim=True) - xhat * (dL/dxhat * xhat).mean(-1, keepdim=True))
dL/dgamma = sum_batch(grad_output * xhat), dL/dbeta = sum_batch(grad_output), where sum_batch sums over every dim except the last.
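One way to sanity-check the backward formulas before wiring them into autograd is to compare them against central finite differences. The sketch below does this in NumPy. The function names (forward, backward, numeric_grad) and the test loss L = sum(w * y) with a fixed random w are assumptions made for illustration, not part of the problem statement.

```python
import numpy as np

EPS = 1e-5

def forward(x, gamma, beta):
    mean = x.mean(-1, keepdims=True)
    sigma = np.sqrt(x.var(-1, keepdims=True) + EPS)  # biased variance
    xhat = (x - mean) / sigma
    return gamma * xhat + beta, xhat, sigma

def backward(grad_output, xhat, sigma, gamma):
    g = grad_output * gamma                          # dL/dxhat
    dx = (g - g.mean(-1, keepdims=True)
          - xhat * (g * xhat).mean(-1, keepdims=True)) / sigma
    d = xhat.shape[-1]
    dgamma = (grad_output * xhat).reshape(-1, d).sum(0)  # sum over non-feature dims
    dbeta = grad_output.reshape(-1, d).sum(0)
    return dx, dgamma, dbeta

def numeric_grad(f, a, h=1e-6):
    """Central differences of scalar f() w.r.t. array a, perturbed in place."""
    grad = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        old = a[idx]
        a[idx] = old + h
        f_plus = f()
        a[idx] = old - h
        f_minus = f()
        a[idx] = old                                 # restore the entry
        grad[idx] = (f_plus - f_minus) / (2 * h)
    return grad

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 5))
gamma = rng.standard_normal(5)
beta = rng.standard_normal(5)
w = rng.standard_normal(x.shape)   # plays the role of grad_output

def loss():
    return float((w * forward(x, gamma, beta)[0]).sum())

_, xhat, sigma = forward(x, gamma, beta)
dx, dgamma, dbeta = backward(w, xhat, sigma, gamma)

print(np.allclose(dx, numeric_grad(loss, x), atol=1e-5))
print(np.allclose(dgamma, numeric_grad(loss, gamma), atol=1e-5))
print(np.allclose(dbeta, numeric_grad(loss, beta), atol=1e-5))
```

The check runs in float64 on purpose: in float32 the finite-difference noise is large enough to mask real mistakes in the formula.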
The driver ln_run(mode, x, gamma, beta) dispatches on mode: 'forward' | 'grad_x' | 'grad_gamma' | 'grad_beta' | 'gradcheck'. Use eps = 1e-5.
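A possible shape for the solution is sketched below; it is not a reference implementation. Several choices are assumptions the statement does not confirm: the class name LayerNormFn, casting inputs to float64, returning torch tensors, and using L = y.sum() as the loss for the three grad modes.

```python
import torch

class LayerNormFn(torch.autograd.Function):
    """Sketch: LayerNorm over the last dim with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, gamma, beta, eps):
        mean = x.mean(-1, keepdim=True)
        xc = x - mean
        var = (xc * xc).mean(-1, keepdim=True)   # biased variance
        rstd = torch.rsqrt(var + eps)            # 1 / sigma
        xhat = xc * rstd
        ctx.save_for_backward(xhat, rstd, gamma)
        return gamma * xhat + beta

    @staticmethod
    def backward(ctx, grad_output):
        xhat, rstd, gamma = ctx.saved_tensors
        g = grad_output * gamma                  # dL/dxhat
        dx = rstd * (g - g.mean(-1, keepdim=True)
                     - xhat * (g * xhat).mean(-1, keepdim=True))
        d = xhat.shape[-1]
        dgamma = (grad_output * xhat).reshape(-1, d).sum(0)
        dbeta = grad_output.reshape(-1, d).sum(0)
        return dx, dgamma, dbeta, None           # no gradient for eps

def ln_run(mode, x, gamma, beta, eps=1e-5):
    # Assumption: fresh float64 leaf tensors, so gradcheck is meaningful.
    x, gamma, beta = (
        torch.as_tensor(t, dtype=torch.float64).detach().clone().requires_grad_(True)
        for t in (x, gamma, beta)
    )
    if mode == "gradcheck":
        fn = lambda a, b, c: LayerNormFn.apply(a, b, c, eps)
        return torch.autograd.gradcheck(fn, (x, gamma, beta))
    y = LayerNormFn.apply(x, gamma, beta, eps)
    if mode == "forward":
        return y.detach()
    y.sum().backward()                           # assumed loss: L = y.sum()
    grads = {"grad_x": x.grad, "grad_gamma": gamma.grad, "grad_beta": beta.grad}
    return grads[mode]                           # unknown modes raise KeyError
```

Saving xhat and rstd in forward rather than x, mean, and var keeps backward to a few elementwise operations and two reductions. If the grader expects a different loss or return type for the grad modes, only ln_run needs to change.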
import numpy as np

def ln_run(mode, x, gamma, beta):
    # Starter stub: replace with your implementation.
    pass