TorchedUp
LearnBetaProblemsSystem DesignSoonPremium

139. Bucketed AllReduce Bucket Count

Medium

DDP overlaps gradient AllReduce with backward compute by bucketing gradients. Compute the optimal number of buckets that minimizes total step time.

Signature: def optimal_bucket_count(grad_bytes: int, latency_per_call_us: float, bandwidth_bytes_per_us: float) -> int

Model: total time = latency * num_buckets + grad_bytes / bandwidth. Larger buckets mean fewer AllReduce calls, but each call blocks compute for longer. Use the heuristic

bucket_size = sqrt(grad_bytes * bandwidth * latency_per_call_us)

then num_buckets = ceil(grad_bytes / bucket_size).

Example:

  • grad_bytes=10_000_000, latency=10us, bandwidth=1000 B/us → bucket_size = sqrt(10_000_000 * 1000 * 10) = sqrt(1e11) ≈ 316228 → ceil(10_000_000 / 316228) = 32 buckets
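
The heuristic above can be sketched in a few lines. This is a minimal sketch, not a reference solution: it evaluates the stated formula literally, uses the standard-library math module rather than NumPy, and its input handling (returning 0 for an empty gradient, raising on non-positive latency or bandwidth) is an assumption the problem statement does not specify. The helper modeled_total_time_us is a hypothetical name added only to evaluate the stated cost model.

```python
import math


def optimal_bucket_count(
    grad_bytes: int,
    latency_per_call_us: float,
    bandwidth_bytes_per_us: float,
) -> int:
    """Bucket count from the stated heuristic: ceil(N / sqrt(N * BW * L))."""
    # Assumption: nothing to reduce means zero buckets.
    if grad_bytes <= 0:
        return 0
    # Assumption: non-positive latency or bandwidth is treated as invalid input.
    if latency_per_call_us <= 0 or bandwidth_bytes_per_us <= 0:
        raise ValueError("latency and bandwidth must be positive")

    # bucket_size = sqrt(grad_bytes * bandwidth * latency), in bytes.
    bucket_size = math.sqrt(
        grad_bytes * bandwidth_bytes_per_us * latency_per_call_us
    )
    return math.ceil(grad_bytes / bucket_size)


def modeled_total_time_us(
    grad_bytes: int,
    latency_per_call_us: float,
    bandwidth_bytes_per_us: float,
    num_buckets: int,
) -> float:
    """Hypothetical helper: the stated model, latency * k + N / bandwidth."""
    return latency_per_call_us * num_buckets + grad_bytes / bandwidth_bytes_per_us


if __name__ == "__main__":
    k = optimal_bucket_count(10_000_000, 10.0, 1000.0)
    print(k, modeled_total_time_us(10_000_000, 10.0, 1000.0, k))
```

Because the bucket size is a floating-point square root, inputs that land exactly on a bucket boundary may be sensitive to rounding before the ceil; an integer-based check would be one way to guard against that if the hidden tests require it.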

Math

$B^{*} = \sqrt{N \cdot BW \cdot L}, \qquad k^{*} = \lceil N / B^{*} \rceil$

NumPy

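import numpy as np

def optimal_bucket_count(grad_bytes: int, latency_per_call_us: float, bandwidth_bytes_per_us: float) -> int:
    # Starter stub: return ceil(grad_bytes / bucket_size) using the heuristic above.
    pass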
