With sqrt(L) gradient checkpointing, you store activations only at sqrt(L) evenly-spaced layers and recompute the rest during backward. Compute the activation memory and FLOPs multiplier.
Signature: def checkpointing_memory_and_flops(activation_bytes_per_layer: int, n_layers: int) -> tuple
Return a tuple (memory_bytes, flops_multiplier) where:
- memory_bytes = activation_bytes_per_layer * round(sqrt(n_layers))
- flops_multiplier = 1.33 (constant — Chen et al. show ~33% FLOP overhead from one extra forward pass per segment)

Use int(round(math.sqrt(n_layers))) for the integer count of stored checkpoints.
Example: With activation_bytes_per_layer = 1_000_000_000 and n_layers = 16, return (4_000_000_000, 1.33).
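A minimal sketch of a solution following the spec above (the constant 1.33 multiplier is taken directly from the problem statement):

```python
import math

def checkpointing_memory_and_flops(activation_bytes_per_layer: int, n_layers: int) -> tuple:
    # Store activations only at ~sqrt(L) evenly spaced checkpoint layers.
    n_checkpoints = int(round(math.sqrt(n_layers)))
    memory_bytes = activation_bytes_per_layer * n_checkpoints
    # One extra forward pass per segment during backward adds ~33% FLOPs
    # (per the problem statement, following Chen et al.).
    return (memory_bytes, 1.33)
```

Running the worked example: `checkpointing_memory_and_flops(1_000_000_000, 16)` returns `(4_000_000_000, 1.33)`, since sqrt(16) = 4 checkpoints are stored.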