
Activation Memory (Transformer)

During the forward pass, every layer stores activations needed for the backward pass. Estimate the activation memory in bytes for a Transformer.

Signature: def activation_memory_bytes(batch: int, seq_len: int, n_layers: int, d_model: int, n_heads: int, dtype_bytes: int = 2) -> int

Use this simplified per-layer estimate (a common back-of-the-envelope shortcut):

per_layer_bytes = 16 * batch * seq_len * d_model * dtype_bytes
total_bytes = n_layers * per_layer_bytes

The constant 16 is a rough average over: the input residual, the attention QK output, the attention probabilities, the value output, the FFN intermediate (4*d_model hidden), the layer norms, and a few small buffers. n_heads is accepted in the signature but not used in this simplified formula; the fuller per-layer estimate from Korthikanti et al. (the selective activation recomputation paper) adds a term that scales like n_heads * seq_len / d_model for the attention score matrices.
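
A direct implementation of the required signature is a minimal sketch of the two lines above; nothing beyond the stated formula is assumed:

def activation_memory_bytes(batch: int, seq_len: int, n_layers: int, d_model: int, n_heads: int, dtype_bytes: int = 2) -> int:
    # n_heads is intentionally unused in this simplified estimate (see note above).
    per_layer_bytes = 16 * batch * seq_len * d_model * dtype_bytes
    return n_layers * per_layer_bytes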

Example: B=2, S=2048, L=12, d=768, fp16 -> 12 * 16 * 2 * 2048 * 768 * 2 = 1,207,959,552 bytes (~1.2 GB).
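
As a quick sanity check, plugging the example values into the sketch above reproduces the stated figure (n_heads=12 is an arbitrary placeholder here, since the argument does not affect the result):

assert activation_memory_bytes(batch=2, seq_len=2048, n_layers=12, d_model=768, n_heads=12) == 1_207_959_552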

Difficulty: Medium · Tags: Math · Language: Python (numpy)

Test cases: small (B=2, S=2048), tiny, medium (Premium)