A naive LayerNorm reads the input three times (mean pass, variance pass, normalization pass). A Welford (one-pass) implementation fuses the mean and variance computations into a single streaming pass and then normalizes, so x is read only twice in total.
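To make the pass counts concrete, here is a minimal pure-Python sketch over a single row that counts how many times each element of x is read. It assumes population variance (as LayerNorm uses), a hypothetical `eps` default of 1e-5, and omits the learnable scale/shift (gamma/beta). The helper names `layernorm_naive` and `layernorm_welford` are illustrative, not from the source.

```python
import math

def layernorm_naive(row, eps=1e-5):
    """Three passes over `row`: mean, variance, normalize. Returns (out, reads)."""
    reads = 0
    # Pass 1: mean
    total = 0.0
    for v in row:
        total += v
        reads += 1
    mean = total / len(row)
    # Pass 2: population variance around the mean
    sq = 0.0
    for v in row:
        sq += (v - mean) ** 2
        reads += 1
    var = sq / len(row)
    # Pass 3: normalize
    inv_std = 1.0 / math.sqrt(var + eps)
    out = []
    for v in row:
        out.append((v - mean) * inv_std)
        reads += 1
    return out, reads

def layernorm_welford(row, eps=1e-5):
    """Welford: one streaming pass for mean and variance, then normalize."""
    reads = 0
    count, mean, m2 = 0, 0.0, 0.0
    # Pass 1: running mean and sum of squared deviations (M2)
    for v in row:
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)  # second factor uses the updated mean
        reads += 1
    var = m2 / count  # population variance
    # Pass 2: normalize
    inv_std = 1.0 / math.sqrt(var + eps)
    out = []
    for v in row:
        out.append((v - mean) * inv_std)
        reads += 1
    return out, reads
```

For a row of length D, `layernorm_naive` reports 3 * D reads and `layernorm_welford` reports 2 * D, while the normalized outputs agree up to floating-point rounding.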
Signature: def layernorm_pass_savings(N: int, D: int, dtype_bytes: int) -> list
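For an (N, D) input:
Two-pass (naive) reads of x: 3 * N * D * dtype_bytes
One-pass (Welford) reads of x: 2 * N * D * dtype_bytes
Return [twopass_reads_bytes, onepass_reads_bytes, savings_bytes] (all ints), where savings_bytes = twopass_reads_bytes - onepass_reads_bytes.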
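A minimal sketch of the requested function, following the byte-count formulas above. It counts only reads of x; writes of the output and reads of any scale/shift parameters are deliberately ignored, as in the problem statement.

```python
def layernorm_pass_savings(N: int, D: int, dtype_bytes: int) -> list:
    # Bytes occupied by one full read of the (N, D) input
    x_bytes = N * D * dtype_bytes
    twopass_reads_bytes = 3 * x_bytes  # mean, variance, normalize
    onepass_reads_bytes = 2 * x_bytes  # fused Welford stats, normalize
    savings_bytes = twopass_reads_bytes - onepass_reads_bytes
    return [twopass_reads_bytes, onepass_reads_bytes, savings_bytes]
```

For example, `layernorm_pass_savings(4096, 4096, 2)` (a 4096 x 4096 fp16 tensor) returns `[100663296, 67108864, 33554432]`: the saving is always exactly one read of x, i.e. N * D * dtype_bytes.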