Gradient accumulation skips AllReduce on intermediate micro-steps and only syncs every accumulation_steps micro-batches. Compute how many AllReduce calls are saved per epoch.
Signature: def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int
Without accumulation: steps_per_epoch AllReduce calls.
With accumulation: steps_per_epoch // accumulation_steps calls.
Return: steps_per_epoch - steps_per_epoch // accumulation_steps.
Example:
steps_per_epoch=1000, accumulation_steps=4 → 1000 - 250 = 750 calls saved.
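A direct implementation of the formula above, using the given signature (integer floor division handles the case where accumulation_steps does not evenly divide steps_per_epoch, since any leftover micro-batches never reach a sync boundary):

```python
def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int:
    """AllReduce calls avoided per epoch when syncing gradients only
    every `accumulation_steps` micro-batches instead of every micro-batch."""
    # Without accumulation: one AllReduce per micro-step.
    # With accumulation: one AllReduce per completed group of accumulation_steps.
    return steps_per_epoch - steps_per_epoch // accumulation_steps

print(comm_calls_saved(1000, 4))  # → 750
```

With accumulation_steps=1 the function returns 0, as expected: syncing every step saves nothing.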