For a transformer with N non-embedding parameters, one forward plus backward pass on a single token costs approximately 6 * N FLOPs: about 2 * N for the forward pass and about 4 * N for the backward pass (Kaplan et al. 2020).
Signature: def flops_per_token(n_params: int) -> int
Return 6 * n_params.
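A minimal sketch of the function described above, using the stated signature. The input check and the `training_flops` helper are assumptions added for illustration; the problem statement only asks for `flops_per_token`.

```python
def flops_per_token(n_params: int) -> int:
    """Approximate training FLOPs for one token (forward + backward pass).

    Uses the 6 * N rule of thumb, where N is the number of
    non-embedding parameters.
    """
    # Assumption: a negative parameter count is treated as invalid input.
    if n_params < 0:
        raise ValueError("n_params must be non-negative")
    return 6 * n_params


def training_flops(n_params: int, n_tokens: int) -> int:
    """Hypothetical helper: scale the per-token estimate to n_tokens tokens."""
    return flops_per_token(n_params) * n_tokens


if __name__ == "__main__":
    # Example with an illustrative 124M-parameter model.
    print(flops_per_token(124_000_000))
    print(training_flops(124_000_000, 1_000))
```

Python integers have arbitrary precision, so the result does not overflow even for very large parameter or token counts.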