TorchedUp

230. FLOPs per Token (Batched)

Medium

Vectorize the 6N rule for FLOPs per token across a batch of model configs. Use the non-embedding parameter count, which for standard GPT-style models is 12 * d_model^2 * n_layers.
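Where the 12 comes from can be sketched per layer. This assumes a standard GPT block with an MLP hidden size of 4 · d_model (the common GPT default, but an assumption here) and ignores biases and LayerNorm weights, which add only O(d_model) per layer. Attention contributes four d_model × d_model projections (Q, K, V, output) and the MLP contributes two d_model × 4·d_model matrices. The helper name `non_embedding_params` is hypothetical.

```python
def non_embedding_params(d_model: int, n_layers: int) -> int:
    """Approximate non-embedding parameter count of a GPT-style decoder.

    Assumes an MLP hidden size of 4 * d_model and ignores biases and
    LayerNorm weights, which contribute only O(d_model) per layer.
    """
    attn = 4 * d_model * d_model          # W_q, W_k, W_v, W_o
    mlp = 2 * d_model * (4 * d_model)     # up-projection + down-projection
    return n_layers * (attn + mlp)        # = 12 * d_model**2 * n_layers


print(non_embedding_params(768, 12))  # 84934656
```

Summing 4d² + 8d² per layer gives the 12 · d_model² · n_layers form used throughout this problem.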

Implement: def flops_per_token_batched(d_model, n_layers), where both arguments are 1-D integer arrays of shape (N,). Return an int64 array of shape (N,).

Per config: the non-embedding parameter count is 12 · d_model² · n_layers, and the per-token FLOPs (forward+backward) follow the 6N rule. See the math reference below for the closed form.

Heads-up: for frontier-scale configs, d_model² · n_layers exceeds the int32 range, so cast the inputs to int64 before multiplying to keep the result from overflowing.
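To see the overflow concretely, the sketch below evaluates the closed form twice, once on int32 inputs and once on int64 inputs. It assumes two example configs: a GPT-2-small-sized one (d_model = 768, n_layers = 12) and the published GPT-3 175B shape (d_model = 12288, n_layers = 96). NumPy integer arrays wrap around silently on overflow rather than raising.

```python
import numpy as np

# Example configs: GPT-2-small-sized and GPT-3-175B-shaped.
d_model = np.array([768, 12288])
n_layers = np.array([12, 96])

# int32 inputs: intermediates for the large config exceed 2**31 - 1 and wrap.
naive = 6 * 12 * d_model.astype(np.int32) ** 2 * n_layers.astype(np.int32)

# int64 inputs: every intermediate stays in range.
safe = 6 * 12 * d_model.astype(np.int64) ** 2 * n_layers.astype(np.int64)

print(naive.dtype, safe.dtype)  # int32 int64
print(safe.tolist())            # [509607936, 1043677052928]
```

The small config agrees in both dtypes; only the large one is corrupted, which is why the bug is easy to miss on small test cases.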

Where the 6 comes from: the rule of thumb is that a forward pass costs about 2 FLOPs per parameter per token (one multiply and one add), i.e. 2N per token, and a backward pass costs about 4 FLOPs per parameter per token (2 for the gradient with respect to each layer's input and 2 for the gradient with respect to its weights), i.e. 4N per token. Total: 6 FLOPs per parameter per token, or 6N per token when summed over all N parameters.
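The 2 + 4 split can be checked on a single linear layer. This is a minimal sketch assuming a bias-free layer Y = X @ W, with X of shape (B, d_in) holding B tokens and W of shape (d_in, d_out), and the usual 2·m·k·n FLOP count for an (m × k) @ (k × n) matmul. The function names are hypothetical.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m x k) @ (k x n) matmul: one multiply + one add per term."""
    return 2 * m * k * n


def linear_layer_flops(batch_tokens: int, d_in: int, d_out: int) -> dict:
    """Forward and backward FLOPs for Y = X @ W with W of shape (d_in, d_out)."""
    forward = matmul_flops(batch_tokens, d_in, d_out)       # Y  = X @ W
    grad_input = matmul_flops(batch_tokens, d_out, d_in)    # dX = dY @ W.T
    grad_weight = matmul_flops(d_in, batch_tokens, d_out)   # dW = X.T @ dY
    return {"forward": forward, "backward": grad_input + grad_weight}


B, d_in, d_out = 32, 1024, 4096
counts = linear_layer_flops(B, d_in, d_out)
params = d_in * d_out
print(counts["forward"] / (B * params))   # 2.0 FLOPs per parameter per token
print(counts["backward"] / (B * params))  # 4.0
```

Each of the three matmuls touches every weight once per token, which is why the total is 3 × 2 = 6 FLOPs per parameter per token.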

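Why exclude embeddings? The 6N rule assumes every parameter participates in a matmul in both the forward and backward passes. An embedding lookup is just a row gather, costing O(d_model) per token regardless of vocab size, so counting all vocab × d_model embedding parameters would substantially overcount FLOPs, especially for small models, where embeddings make up a large share of the total.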
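To gauge the size of that overcount, compare the two parameter groups for a small model. This sketch assumes the GPT-2 small shape (d_model = 768, n_layers = 12, vocab size 50,257) and counts only the token-embedding matrix; position embeddings and any tied output projection are left out.

```python
d_model, n_layers, vocab = 768, 12, 50_257

non_embedding = 12 * d_model**2 * n_layers   # matmul parameters
embedding = vocab * d_model                  # token-embedding table (a lookup, not a matmul)

print(non_embedding)                         # 84934656
print(embedding)                             # 38597376
print(round(embedding / non_embedding, 3))   # 0.454
```

Under these assumptions, including the embedding table would inflate N, and therefore the 6N estimate, by roughly 45% for a model of this size, while the share shrinks as d_model and n_layers grow.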

Math

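$$F_i = 6 \cdot N_i = 6 \cdot 12 \cdot d_{\text{model},i}^{2} \cdot L_i$$

where N_i is the non-embedding parameter count and L_i is n_layers for config i.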
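Because Python integers have arbitrary precision, the closed form can be evaluated config by config in plain Python as an overflow-free reference for cross-checking a vectorized version. This is a sketch; the function name and the example configs are hypothetical.

```python
def flops_per_token_reference(d_models, n_layers):
    """Plain-Python F_i = 6 * 12 * d_model_i**2 * L_i for each config (never overflows)."""
    return [6 * 12 * d * d * L for d, L in zip(d_models, n_layers)]


print(flops_per_token_reference([512, 4096], [8, 32]))  # [150994944, 38654705664]
```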


NumPy

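import numpy as np


def flops_per_token_batched(d_model, n_layers):
    # d_model, n_layers: 1-D integer arrays of shape (N,)
    # Return: int64 array of shape (N,) with per-token training FLOPs
    pass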
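One way to fill in the stub, offered as a minimal sketch rather than the site's reference solution: cast both inputs to int64 up front with np.asarray(..., dtype=np.int64), then apply the closed form elementwise so NumPy handles the whole batch in one vectorized expression.

```python
import numpy as np


def flops_per_token_batched(d_model, n_layers):
    """Per-token training FLOPs (6N rule) for N model configs.

    d_model, n_layers: 1-D integer arrays of shape (N,).
    Returns an int64 array of shape (N,).
    """
    # Cast before multiplying: 12 * d_model**2 * n_layers can exceed int32.
    d = np.asarray(d_model, dtype=np.int64)
    L = np.asarray(n_layers, dtype=np.int64)

    n_params = 12 * d * d * L   # non-embedding parameter count N
    return 6 * n_params         # forward (2N) + backward (4N)


out = flops_per_token_batched(np.array([512, 4096, 12288]), np.array([8, 32, 96]))
print(out.dtype, out.tolist())
# int64 [150994944, 38654705664, 1043677052928]
```

Casting inside the function, rather than trusting the caller's dtype, matters because NumPy's default integer is 32-bit on some platforms (for example, Windows with NumPy 1.x).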
