TorchedUp
Difficulty: Medium

LayerNorm: One-pass vs Two-pass

A naive LayerNorm reads the input three times: a mean pass, a variance pass, and a normalization pass. It is called "two-pass" here because computing the statistics takes two passes over x. A Welford (one-pass) implementation fuses the mean and variance into a single streaming pass and then normalizes, for only two reads of x in total.
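The fused statistics pass described above can be sketched with Welford's streaming update. This is a minimal numpy illustration, not the site's reference solution: the function name `welford_layernorm`, the `eps` default, and the omission of the learnable scale and shift are all assumptions made for the example.

```python
import numpy as np

def welford_layernorm(x, eps=1e-5):
    """LayerNorm over the last axis of an (N, D) array, without affine parameters.

    Hypothetical helper: mean and variance are accumulated together in one
    streaming pass (Welford's update), then x is read once more to normalize.
    """
    x = np.asarray(x, dtype=np.float64)
    n_rows, d = x.shape
    mean = np.zeros(n_rows)
    m2 = np.zeros(n_rows)  # running sum of squared deviations from the mean

    # Read 1 of x: update mean and M2 together, one feature column at a time.
    for k in range(d):
        col = x[:, k]
        delta = col - mean
        mean += delta / (k + 1)
        m2 += delta * (col - mean)  # uses the already-updated mean

    var = m2 / d  # population variance (divide by D, not D - 1)

    # Read 2 of x: normalize each row with its statistics.
    return (x - mean[:, None]) / np.sqrt(var[:, None] + eps)
```

The Python loop is only there to make the streaming order explicit; a real kernel would keep the running mean and M2 in registers while it walks each row.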

Signature: def layernorm_pass_savings(N: int, D: int, dtype_bytes: int) -> list

For a (N, D) input:

  • Two-pass total reads of x: 3 * N * D * dtype_bytes
  • One-pass total reads of x: 2 * N * D * dtype_bytes

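Return [twopass_reads_bytes, onepass_reads_bytes, savings_bytes] (all ints), where savings_bytes is the difference between the two: twopass_reads_bytes - onepass_reads_bytes = N * D * dtype_bytes.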
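One way to turn the formulas above into code is shown below. It is a sketch that follows the stated signature, and it assumes savings_bytes means the two-pass total minus the one-pass total. The example values are illustrative and are not taken from the site's test cases.

```python
def layernorm_pass_savings(N: int, D: int, dtype_bytes: int) -> list:
    # Bytes moved by a single full read of the (N, D) input.
    bytes_per_read = N * D * dtype_bytes

    twopass_reads_bytes = 3 * bytes_per_read  # mean + variance + normalize
    onepass_reads_bytes = 2 * bytes_per_read  # fused statistics + normalize
    savings_bytes = twopass_reads_bytes - onepass_reads_bytes

    return [twopass_reads_bytes, onepass_reads_bytes, savings_bytes]


# Illustrative call: a 2 x 4 fp32 input (4 bytes per element).
print(layernorm_pass_savings(2, 4, 4))  # [96, 64, 32]
```

Since the saving is always exactly one read of x, the one-pass variant cuts read traffic by a third regardless of N, D, or dtype.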


Test cases: small fp32, fp16 hidden state, large bf16.