TorchedUp

140. Gradient Accumulation Comm Savings

Easy

Gradient accumulation skips AllReduce on intermediate micro-steps and only syncs every accumulation_steps micro-batches. Compute how many AllReduce calls are saved per epoch.

Signature: def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int

Without accumulation: steps_per_epoch AllReduce calls. With accumulation: steps_per_epoch // accumulation_steps calls.

Return: steps_per_epoch - steps_per_epoch // accumulation_steps.

Example:

  • steps_per_epoch=1000, accumulation_steps=4 → 1000 - 250 = 750 calls saved.
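The closed form can be cross-checked by walking the micro-steps of one epoch and counting the syncs directly. This is a minimal sketch, not a reference solution: count_allreduce_calls is a hypothetical helper, and it assumes a trailing partial accumulation window triggers no AllReduce, which is what the floor division in the statement implies.

```python
def count_allreduce_calls(steps_per_epoch: int, accumulation_steps: int) -> int:
    """Count syncs by walking one epoch's micro-steps (hypothetical helper)."""
    calls = 0
    for step in range(1, steps_per_epoch + 1):
        # Sync only when a full accumulation window has just completed.
        # Assumption: a trailing partial window is never synced.
        if step % accumulation_steps == 0:
            calls += 1
    return calls


def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int:
    # Baseline is one AllReduce per micro-step; subtract the simulated count.
    return steps_per_epoch - count_allreduce_calls(steps_per_epoch, accumulation_steps)


print(comm_calls_saved(1000, 4))  # 750, matching the example above
```

The loop is O(steps_per_epoch) and exists only to make the counting explicit; the one-line formula gives the same result in constant time.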

Math

saved = S − ⌊S/A⌋, where S = steps_per_epoch and A = accumulation_steps
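Plugging the example values into the formula gives the following worked instance, written in LaTeX with S = steps_per_epoch and A = accumulation_steps. The saved fraction on the right is an added observation rather than part of the problem statement, and it holds exactly only when A divides S.

```latex
\[
\text{saved} = S - \left\lfloor \frac{S}{A} \right\rfloor
             = 1000 - \left\lfloor \frac{1000}{4} \right\rfloor
             = 1000 - 250 = 750,
\qquad
\frac{\text{saved}}{S} = 1 - \frac{1}{A} = 0.75 \quad \text{when } A \mid S .
\]
```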

NumPy

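import numpy as np

def comm_calls_saved(steps_per_epoch: int, accumulation_steps: int) -> int:
    # One AllReduce per full accumulation window instead of one per micro-step.
    return steps_per_epoch - steps_per_epoch // accumulation_steps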
