For a transformer with N non-embedding parameters, one forward plus backward pass over a single token costs approximately 6 * N FLOPs (Kaplan et al. 2020): roughly 2 * N for the forward pass and 4 * N for the backward pass.
Signature: def flops_per_token(n_params: int) -> int
Return 6 * n_params.
def flops_per_token(n_params: int) -> int:
    # ~2N FLOPs forward + ~4N backward per token (Kaplan et al. 2020)
    return 6 * n_params
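The 6 * N rule comes from adding up two parts: about 2 * N FLOPs for the forward pass (one multiply and one add per parameter) and about twice that for the backward pass, which computes gradients with respect to both activations and weights. The sketch below makes that split explicit and scales it to a whole training run as C ≈ 6 * N * D for D tokens. The helper names `flops_breakdown` and `training_flops` are hypothetical, and the model size is only illustrative. Like the 6 * N rule itself, the sketch leaves out the smaller context-dependent attention term.

```python
def flops_breakdown(n_params: int) -> dict:
    """Approximate per-token FLOPs for a dense transformer (hypothetical helper)."""
    forward = 2 * n_params   # one multiply + one add per non-embedding parameter
    backward = 2 * forward   # grads w.r.t. activations and weights: ~2x forward
    return {"forward": forward, "backward": backward, "total": forward + backward}


def training_flops(n_params: int, n_tokens: int) -> int:
    """Total training compute C ~= 6 * N * D for D tokens (hypothetical helper)."""
    return flops_breakdown(n_params)["total"] * n_tokens


if __name__ == "__main__":
    n = 1_000_000  # illustrative: 1M non-embedding parameters
    print(flops_breakdown(n))      # {'forward': 2000000, 'backward': 4000000, 'total': 6000000}
    print(training_flops(n, 10_000))  # 60000000000
```

The per-token total always equals 6 * n_params, which is the value `flops_per_token` is expected to return.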