Training deep networks requires storing all intermediate activations during the forward pass for use in backpropagation. For an N-layer network this is O(N) memory — a bottleneck for very deep models.
Gradient checkpointing trades compute for memory: only save activations at every K-th layer ("checkpoints"). During backward, recompute the missing intermediate activations from the nearest checkpoint just before they're needed.
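As a rough illustration (not part of the problem statement): with N = 100 layers and K = 10, the forward pass keeps only ~10 checkpointed activations, and the backward pass recomputes at most 10 activations per segment, so peak activation memory is on the order of N/K + K ≈ 20 instead of 100, at the cost of roughly one extra forward pass. Choosing K ≈ √N minimizes that sum.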
Implement a simplified version with N linear + tanh layers:
Each layer: h = tanh(W @ h_prev)
Compute the gradient of mean(final_output) w.r.t. the input x
Signature: def gradient_checkpointing_backward(x, weights, checkpoint_every=2)
x: (d,) input vector
weights: list of (d, d) weight matrices, one per layer
checkpoint_every: save every K-th layer activation
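Below is a minimal NumPy sketch of one way to implement this. It follows the signature above; the segment bookkeeping (a dict of checkpointed activations, backward processing one segment at a time) is one reasonable design choice, not a prescribed one.

```python
import numpy as np

def gradient_checkpointing_backward(x, weights, checkpoint_every=2):
    """Gradient of mean(h_N) w.r.t. x, where h_0 = x and h_i = tanh(W_i @ h_{i-1}).

    Only every `checkpoint_every`-th activation is stored during the forward
    pass; the rest are recomputed from the nearest checkpoint during backward.
    """
    n_layers = len(weights)
    d = x.shape[0]

    # Forward pass: keep only checkpointed activations (including h_0 = x).
    checkpoints = {0: x}
    h = x
    for i, W in enumerate(weights, start=1):
        h = np.tanh(W @ h)
        if i % checkpoint_every == 0:
            checkpoints[i] = h

    # Gradient of loss = mean(h_N) with respect to h_N.
    grad_h = np.full(d, 1.0 / d)

    # Backward pass, one segment of layers (start, end] at a time.
    end = n_layers
    while end > 0:
        # Nearest checkpoint at or below layer end - 1.
        start = (end - 1) // checkpoint_every * checkpoint_every
        # Recompute the activations h_start .. h_end from that checkpoint.
        acts = [checkpoints[start]]
        for i in range(start + 1, end + 1):
            acts.append(np.tanh(weights[i - 1] @ acts[-1]))
        # Backprop through layers end, end-1, ..., start+1.
        for i in range(end, start, -1):
            h_i = acts[i - start]                  # tanh output of layer i
            grad_pre = grad_h * (1.0 - h_i ** 2)   # through tanh
            grad_h = weights[i - 1].T @ grad_pre   # through W_i @ h_{i-1}
        end = start

    return grad_h


if __name__ == "__main__":
    # Hypothetical usage with random data, for illustration only.
    rng = np.random.default_rng(0)
    x = rng.normal(size=4)
    weights = [rng.normal(size=(4, 4)) * 0.5 for _ in range(6)]
    print(gradient_checkpointing_backward(x, weights, checkpoint_every=2))
```

A simple sanity check is to compare the result against a plain backward pass that stores every activation (or a numerical gradient); the two should agree, since checkpointing changes only what is stored, not the math.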