
Gradient Checkpointing Savings

With sqrt(L) gradient checkpointing, you store activations only at sqrt(L) evenly spaced layers and recompute the rest during the backward pass. Compute the resulting activation memory and FLOPs multiplier.

Signature: def checkpointing_memory_and_flops(activation_bytes_per_layer: int, n_layers: int) -> tuple

Return a tuple (memory_bytes, flops_multiplier) where:

  • memory_bytes = activation_bytes_per_layer * round(sqrt(n_layers))
  • flops_multiplier = 1.33 (constant; Chen et al. report roughly 33% FLOP overhead because each segment is recomputed with one extra forward pass: with backward costing about 2x a forward, (1 + 2 + 1) / (1 + 2) ≈ 1.33)

Use int(round(math.sqrt(n_layers))) for the integer count of stored checkpoints.

Example: With activation_bytes_per_layer = 1_000_000_000 and n_layers = 16, return (4_000_000_000, 1.33).
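A minimal sketch of a solution, assuming the constant 1.33 multiplier and the int(round(math.sqrt(n_layers))) checkpoint count given above; the print call at the bottom is illustrative only.

    import math

    def checkpointing_memory_and_flops(activation_bytes_per_layer: int, n_layers: int) -> tuple:
        # Number of stored checkpoints: nearest integer to sqrt(L).
        n_checkpoints = int(round(math.sqrt(n_layers)))
        # Activation memory is only what the stored checkpoints occupy.
        memory_bytes = activation_bytes_per_layer * n_checkpoints
        # Recomputation adds roughly one extra forward pass per segment,
        # ~33% FLOP overhead (constant per the problem statement).
        flops_multiplier = 1.33
        return (memory_bytes, flops_multiplier)

    # Example from the statement: 1 GB per layer, 16 layers -> (4_000_000_000, 1.33)
    print(checkpointing_memory_and_flops(1_000_000_000, 16))

Note that round(math.sqrt(n_layers)) is the nearest integer, not the floor, so L = 30 gives 5 stored checkpoints (sqrt(30) ≈ 5.48).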

Test cases: L = 16, L = 36, L = 100.