Gradient Accumulation — Comm Calls Saved

Gradient accumulation skips the AllReduce on intermediate micro-steps and only synchronizes gradients every accumulation_steps micro-batches. Compute how many AllReduce calls this saves per epoch.
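As a concrete illustration of the sync pattern (not part of the required solution), here is a small hypothetical helper that marks which micro-steps would still trigger an AllReduce for a given accumulation_steps:

    def sync_steps(steps_per_epoch: int, accumulation_steps: int) -> list[int]:
        # An AllReduce fires only on the last micro-step of each complete group,
        # i.e. when (step + 1) is a multiple of accumulation_steps.
        return [s for s in range(steps_per_epoch) if (s + 1) % accumulation_steps == 0]

    # With 10 micro-steps and accumulation_steps=4, only micro-steps 3 and 7 sync;
    # the trailing partial group (steps 8-9) never triggers an AllReduce,
    # which is why the synced count is steps_per_epoch // accumulation_steps.
    assert sync_steps(10, 4) == [3, 7]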

Signature: def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int

Without accumulation: steps_per_epoch AllReduce calls. With accumulation: steps_per_epoch // accumulation_steps calls.

Return: steps_per_epoch - steps_per_epoch // accumulation_steps.

Example:

  • steps_per_epoch=1000, accumulation_steps=4 → 1000 - 250 = 750 calls saved.
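A minimal solution sketch for the signature above, assuming integer inputs with accumulation_steps >= 1:

    def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int:
        """AllReduce calls avoided per epoch when syncing every accumulation_steps micro-batches."""
        # Baseline: one AllReduce per micro-step -> steps_per_epoch calls.
        # With accumulation: an AllReduce fires only on the last micro-step of each
        # complete group -> steps_per_epoch // accumulation_steps calls.
        return steps_per_epoch - steps_per_epoch // accumulation_steps

    assert comm_calls_saved(1000, 4) == 750   # example above
    assert comm_calls_saved(1000, 1) == 0     # no accumulation: nothing is saved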
