TorchedUp
Gradient Checkpointing · Medium

Gradient Checkpointing

Training deep networks requires storing all intermediate activations during the forward pass for use in backpropagation. For an N-layer network this is O(N) memory — a bottleneck for very deep models.

Gradient checkpointing trades compute for memory: only save activations at every K-th layer ("checkpoints"). During backward, recompute the missing intermediate activations from the nearest checkpoint just before they're needed.

  • Standard memory: O(N)
  • Checkpointing memory: O(N/K) checkpoints + an O(K) recompute window = O(N/K + K), minimized to O(√N) at K = √N
  • Compute overhead: roughly 1 extra forward pass total
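The O(√N) figure comes from minimizing the total f(K) = N/K + K over K. A quick numeric check (the layer count N here is arbitrary, chosen only for illustration):

```python
import math

N = 100                                          # arbitrary layer count
cost = {K: N / K + K for K in range(1, N + 1)}   # checkpoints + recompute window
best = min(cost, key=cost.get)
print(best, math.isqrt(N))                       # K = 10 = sqrt(N) minimizes cost
```

Setting f'(K) = -N/K² + 1 = 0 gives the same answer analytically: K = √N, with f(√N) = 2√N.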

Implement a simplified version with N linear + tanh layers:

  1. Forward: run all N layers, only store activations at layers 0, K, 2K, … (plus input)
  2. Backward: for each layer i (reversed), find the nearest checkpoint ≤ i, recompute forward from that checkpoint to layer i, then backprop through tanh and the linear layer
  3. Return the gradient of mean(final_output) w.r.t. the input x

Each layer: h = tanh(W @ h_prev)
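Because tanh'(z) = 1 − tanh(z)² = 1 − h², the backward step for one layer needs only that layer's output h, which is exactly what recomputation provides. A sketch of a single layer's forward and backward (all variable names here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
W = rng.standard_normal((d, d))
h_prev = rng.standard_normal(d)
g_out = np.ones(d) / d             # upstream gradient, e.g. from mean()

h = np.tanh(W @ h_prev)            # forward
g_pre = g_out * (1.0 - h * h)      # backward through tanh: 1 - tanh(z)^2
g_prev = W.T @ g_pre               # backward through the linear map
```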

Signature: def gradient_checkpointing_backward(x, weights, checkpoint_every=2)

  • x: (d,) input vector
  • weights: list of (d, d) weight matrices, one per layer
  • checkpoint_every: save every K-th layer activation
  • Returns: (d,) gradient w.r.t. x
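Putting steps 1–3 together, here is a minimal NumPy sketch of the required function — a reference implementation under the stated signature, not the site's official solution:

```python
import numpy as np

def gradient_checkpointing_backward(x, weights, checkpoint_every=2):
    """Backprop through h = tanh(W @ h_prev) layers, storing only
    every K-th activation and recomputing the rest on demand."""
    x = np.asarray(x, dtype=float)
    n, K = len(weights), checkpoint_every

    # Forward: keep the input (layer 0) plus every K-th activation.
    ckpt, h = {0: x}, x
    for i, W in enumerate(weights):
        h = np.tanh(W @ h)
        if (i + 1) % K == 0:
            ckpt[i + 1] = h

    g = np.ones(len(x)) / len(x)            # d mean(h_n) / d h_n
    for i in reversed(range(n)):
        j = max(k for k in ckpt if k <= i)  # nearest checkpoint <= i
        h = ckpt[j]
        for W in weights[j:i]:              # recompute h_j -> h_i
            h = np.tanh(W @ h)
        out = np.tanh(weights[i] @ h)       # h_{i+1}; tanh' = 1 - out^2
        g = weights[i].T @ (g * (1.0 - out * out))
    return g
```

With checkpoint_every=1 every activation is stored and no recomputation happens, so that setting doubles as a sanity check against the checkpointed path.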

Math

Python (numpy)

Test Results

  • 4-layer network, d=3
  • 2-layer network, d=2
  • 4-layer network, checkpoint_every=1 (store all) (Premium)