TorchedUp

116. Gradient Checkpointing Savings

Medium

With sqrt(L) gradient checkpointing (L = number of layers), you store activations only at sqrt(L) evenly spaced layers, the checkpoints, and recompute the remaining activations during the backward pass. Compute the resulting activation memory and the FLOPs multiplier.
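To make the schedule concrete, here is a minimal sketch, assuming checkpoints sit at the start of round(sqrt(L)) equal-length segments and that backward replays each segment's forward once, last segment first. The helpers checkpoint_layers and backward_replay_order are hypothetical names for illustration, not part of the problem's API.

```python
import math


def checkpoint_layers(n_layers: int) -> list[int]:
    # Hypothetical helper: indices of the round(sqrt(L)) evenly spaced layers whose inputs are kept.
    k = int(round(math.sqrt(n_layers)))  # assumes n_layers >= 1
    step = n_layers / k  # segment length; fractional when L is not a perfect square
    return [int(i * step) for i in range(k)]


def backward_replay_order(n_layers: int) -> list[tuple[int, int]]:
    # Hypothetical helper: segments (start, end) in the order backward re-runs their forward.
    bounds = checkpoint_layers(n_layers) + [n_layers]
    segments = list(zip(bounds[:-1], bounds[1:]))
    return segments[::-1]  # backward visits the last segment first


order = backward_replay_order(16)
print(checkpoint_layers(16))  # [0, 4, 8, 12]
print(order)  # [(12, 16), (8, 12), (4, 8), (0, 4)]
print(sum(end - start for start, end in order))  # 16: every layer's forward is replayed once
```

The sketch only illustrates the schedule; it does not track the transient activations of the segment currently being replayed, which the problem's memory formula also leaves out.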

Signature: def checkpointing_memory_and_flops(activation_bytes_per_layer: int, n_layers: int) -> tuple

Return a tuple (memory_bytes, flops_multiplier) where:

  • memory_bytes = activation_bytes_per_layer * round(sqrt(n_layers))
  • flops_multiplier = 1.33 (a constant: Chen et al. report ~33% FLOP overhead, because each segment's forward is re-run once during backward, which adds up to about one extra full forward pass)
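Where the ~33% comes from, as a rough derivation that assumes the common rule of thumb that a backward pass costs about twice a forward pass (F denotes the cost of one full forward pass; this 2× ratio is an assumption, not stated in the problem):

```latex
\text{baseline} \approx F + 2F = 3F, \qquad
\text{checkpointed} \approx F + \underbrace{F}_{\text{recompute}} + 2F = 4F, \qquad
\frac{4F}{3F} \approx 1.33
```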

Use int(round(math.sqrt(n_layers))) for the integer count of stored checkpoints.

Example: With activation_bytes_per_layer = 1_000_000_000 and n_layers = 16, return (4_000_000_000, 1.33).
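A minimal sketch that follows the two formulas above literally; it is one straightforward way to satisfy the stated spec, not the site's hidden reference solution.

```python
import math


def checkpointing_memory_and_flops(activation_bytes_per_layer: int, n_layers: int) -> tuple:
    # Number of stored checkpoints: sqrt(L), rounded to the nearest integer as the spec requires.
    n_checkpoints = int(round(math.sqrt(n_layers)))
    memory_bytes = activation_bytes_per_layer * n_checkpoints
    # Fixed overhead factor from the spec (about one extra forward pass).
    flops_multiplier = 1.33
    return (memory_bytes, flops_multiplier)


print(checkpointing_memory_and_flops(1_000_000_000, 16))  # (4000000000, 1.33)
```

Python's round uses round-half-to-even, but sqrt of an integer L can never land exactly on k + 0.5, so ties do not arise here.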

Math

$$M_{\text{ckpt}} = M_{\text{layer}} \cdot \sqrt{L}, \qquad \text{FLOPs} = 1.33 \times \text{baseline}$$
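To see the size of the savings, the sketch below compares the checkpointed memory with a store-every-layer baseline of M_layer · L (the baseline is an assumption used for comparison; the problem only asks for the checkpointed figure). The function name savings_table is hypothetical.

```python
import math


def savings_table(activation_bytes_per_layer: int, layer_counts: list[int]) -> list[tuple[int, int, int]]:
    # Hypothetical helper: (L, store-all bytes, checkpointed bytes) for each layer count.
    rows = []
    for n_layers in layer_counts:
        full = activation_bytes_per_layer * n_layers  # keep every layer's activations
        ckpt = activation_bytes_per_layer * int(round(math.sqrt(n_layers)))  # sqrt(L) checkpoints
        rows.append((n_layers, full, ckpt))
    return rows


for n_layers, full, ckpt in savings_table(1_000_000_000, [16, 64, 100]):
    print(f"L={n_layers:>3}: store-all {full / 1e9:.0f} GB, checkpointed {ckpt / 1e9:.0f} GB, {full // ckpt}x smaller")
```

The reduction factor grows like sqrt(L), so deeper networks benefit more, while the FLOPs multiplier stays fixed at 1.33.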


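import numpy as np


def checkpointing_memory_and_flops(activation_bytes_per_layer: int, n_layers: int) -> tuple:
    # Return (memory_bytes, flops_multiplier) as specified above.
    pass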
