Gradient accumulation skips AllReduce on intermediate micro-steps and only syncs every accumulation_steps micro-batches. Compute how many AllReduce calls are saved per epoch.
Signature: def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int
Without accumulation: steps_per_epoch AllReduce calls.
With accumulation: steps_per_epoch // accumulation_steps calls.
Return: steps_per_epoch - steps_per_epoch // accumulation_steps.
Example:
steps_per_epoch=1000, accumulation_steps=4 → 1000 - 250 = 750 calls saved.
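As a sanity check on the formula, the sketch below counts syncs by walking through the micro-steps one at a time instead of using the closed form. The helper name count_allreduce_calls is hypothetical and not part of the problem. It assumes a trailing partial group of micro-batches is never synced, which is what the floor division in the statement implies.

```python
def count_allreduce_calls(steps_per_epoch: int, accumulation_steps: int) -> int:
    """Hypothetical helper: count AllReduce calls by simulating the micro-steps."""
    calls = 0
    for step in range(1, steps_per_epoch + 1):
        # Sync only once a full group of accumulation_steps micro-batches is done.
        # Assumption: a leftover partial group at the end of the epoch is not synced.
        if step % accumulation_steps == 0:
            calls += 1
    return calls


baseline = count_allreduce_calls(1000, 1)    # no accumulation: sync on every step
with_accum = count_allreduce_calls(1000, 4)  # sync on every 4th micro-batch
saved = baseline - with_accum
print(saved)  # 750, matching the example above
```

When steps_per_epoch is not a multiple of accumulation_steps, for instance 10 and 3, the loop counts 3 syncs and 7 calls saved, which agrees with 10 - 10 // 3.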
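def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int:
    # Calls without accumulation minus calls with accumulation.
    return steps_per_epoch - steps_per_epoch // accumulation_steps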