For full-fp32 training with the Adam optimizer, compute the number of bytes of memory needed for parameters, gradients, and optimizer state.
Signature: def adam_training_memory_bytes(n_params: int) -> int
Full-fp32 training with Adam stores four fp32 tensors, each with one element per parameter:
- parameters (4 bytes)
- gradients (4 bytes)
- Adam first moment m (4 bytes)
- Adam second moment v (4 bytes)

Total: 16 bytes per parameter.
Example: 1B params -> 16 GB just for parameters/grads/state (excludes activations).
Return n_params * 16.
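A minimal sketch of the function, directly implementing the 16-bytes-per-parameter accounting above:

```python
def adam_training_memory_bytes(n_params: int) -> int:
    # Per parameter: param (4) + grad (4) + Adam m (4) + Adam v (4) = 16 bytes.
    return n_params * 16


# Example from above: 1B parameters -> 16e9 bytes (16 GB), activations excluded.
assert adam_training_memory_bytes(1_000_000_000) == 16_000_000_000
```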