TorchedUp

71. Gradient Checkpointing

Medium

Training deep networks requires storing all intermediate activations during the forward pass for use in backpropagation. For an N-layer network this is O(N) memory — a bottleneck for very deep models.

Gradient checkpointing trades compute for memory: only save activations at every K-th layer ("checkpoints"). During backward, recompute the missing intermediate activations from the nearest checkpoint just before they're needed.

  • Standard memory: O(N)
  • Checkpointing memory: O(N/K) saved + O(K) recomputed window = O(√N) when K = √N
  • Compute overhead: roughly one extra forward pass in total, provided each segment between checkpoints is recomputed only once
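To make the memory trade-off concrete, here is a rough sketch that counts how many activations are held at once for a given K. The helper name `peak_stored_activations` is hypothetical, and the count is a simplification (checkpoints plus one recomputed segment), not an exact memory model.

```python
import math

def peak_stored_activations(n_layers, k):
    """Rough count of activations held at once (hypothetical helper).

    Assumes one activation is saved every k layers and that at most one
    segment of k activations is recomputed at any time during backward.
    """
    checkpoints = math.ceil(n_layers / k)  # saved at layers 0, K, 2K, ...
    window = k                             # recomputed inside one segment
    return checkpoints + window

n_layers = 100
for k in (1, 5, 10, 20, 100):
    print(k, peak_stored_activations(n_layers, k))
```

For N = 100 the count is smallest at K = 10, which matches the K = √N rule of thumb.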

Implement a simplified version with N linear + tanh layers:

  1. Forward: run all N layers, only store activations at layers 0, K, 2K, … (plus input)
  2. Backward: for each layer i (reversed), find the nearest checkpoint ≤ i, recompute forward from that checkpoint to layer i, then backprop through tanh and the linear layer
  3. Return the gradient of mean(final_output) w.r.t. the input x

Each layer: h = tanh(W @ h_prev)
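A single layer and its backward step can be sketched as follows. The helper names `layer_forward` and `layer_backward` are illustrative and not part of the required signature; the backward step relies on the identity tanh'(z) = 1 - tanh(z)^2, so only the layer output is needed.

```python
import numpy as np

def layer_forward(W, h_prev):
    """One layer: h = tanh(W @ h_prev)."""
    return np.tanh(W @ h_prev)

def layer_backward(W, h, grad_h):
    """Map the gradient w.r.t. the layer output h to the gradient w.r.t. h_prev."""
    grad_z = grad_h * (1.0 - h ** 2)  # backprop through tanh, using h = tanh(z)
    return W.T @ grad_z               # backprop through the linear map z = W @ h_prev
```

Because the tanh derivative is expressed through the output h, recomputing h from a checkpoint is enough to run the backward step for that layer.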

Signature: def gradient_checkpointing_backward(x, weights, checkpoint_every=2)

  • x: (d,) input vector
  • weights: list of (d, d) weight matrices, one per layer
  • checkpoint_every: save every K-th layer activation
  • Returns: (d,) gradient w.r.t. x
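One possible implementation of this signature is sketched below; it is not the site's reference solution. It assumes that activation index i means the input to layer i, so index 0 is x itself, and that checkpoints live at indices 0, K, 2K, and so on.

```python
import numpy as np

def gradient_checkpointing_backward(x, weights, checkpoint_every=2):
    n = len(weights)
    k = checkpoint_every
    h = np.asarray(x, dtype=float)

    # Forward: keep only activations at indices 0, K, 2K, ...
    # Assumption: index i is the input to layer i, so index 0 is x.
    checkpoints = {0: h}
    for i, W in enumerate(weights):
        h = np.tanh(W @ h)
        if (i + 1) % k == 0 and (i + 1) < n:
            checkpoints[i + 1] = h

    # Gradient of mean(final_output) w.r.t. the final output.
    grad = np.full(h.shape, 1.0 / h.size)

    # Backward: recompute from the nearest checkpoint <= i for each layer.
    for i in reversed(range(n)):
        start = (i // k) * k
        h_in = checkpoints[start]
        for j in range(start, i):
            h_in = np.tanh(weights[j] @ h_in)
        h_out = np.tanh(weights[i] @ h_in)
        # Backprop through tanh, then through the linear layer.
        grad = weights[i].T @ (grad * (1.0 - h_out ** 2))
    return grad
```

This follows the per-layer recomputation described in the steps above, which repeats work inside each segment. Recomputing a segment once and reusing it for all of its layers would bring the overhead down to about one extra forward pass.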

Math

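$$h_i = \tanh(W_i h_{i-1}), \quad i = 1, \dots, N, \qquad h_0 = x$$

$$\frac{\partial \bar{h}_N}{\partial x} = \left[ \prod_{i=1}^{N} W_i^{\top} \operatorname{diag}\left(1 - h_i^2\right) \right] \frac{1}{d}\mathbf{1}$$

Here $\bar{h}_N$ is the mean of the final output, the factors are ordered left to right by increasing $i$, and $W_i$ corresponds to weights[i-1] in the code.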
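As a cross-check, the product formula can be evaluated directly by storing every activation, which is the standard O(N) memory approach. The function name `naive_input_gradient` is hypothetical; the sketch is meant as a baseline to compare a checkpointed implementation against, not as the intended solution.

```python
import numpy as np

def naive_input_gradient(x, weights):
    """Gradient of mean(final_output) w.r.t. x, storing all activations."""
    h = np.asarray(x, dtype=float)
    activations = []
    for W in weights:
        h = np.tanh(W @ h)
        activations.append(h)  # O(N) memory: every layer output is kept

    # Start from d(mean)/d(h_N) = (1/d) * ones, then apply the factors
    # W^T diag(1 - h^2) from the last layer back to the first.
    grad = np.full(h.shape, 1.0 / h.size)
    for W, h_out in zip(reversed(weights), reversed(activations)):
        grad = W.T @ ((1.0 - h_out ** 2) * grad)
    return grad
```

Multiplying by diag(1 - h^2) is done elementwise here, which avoids building the d × d diagonal matrix.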

NumPy

import numpy as np

def gradient_checkpointing_backward(x, weights, checkpoint_every=2):
    # Starter stub: return the (d,) gradient of mean(final_output) w.r.t. x
    pass
