TorchedUp
Difficulty: Medium

Bucketed AllReduce — Optimal Bucket Count

DDP overlaps gradient AllReduce with backward compute by bucketing gradients. Compute the optimal number of buckets that minimizes total step time.
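To make "bucketing" concrete, here is a simplified greedy packer. It is a hypothetical helper (`greedy_buckets` is not a PyTorch function) that assumes gradients become ready roughly last-layer-first and fills each bucket up to a byte cap before starting the next. PyTorch's DistributedDataParallel exposes the cap through its `bucket_cap_mb` argument, but its real parameter-to-bucket assignment has more rules than this sketch.

```python
from typing import List


def greedy_buckets(grad_sizes: List[int], bucket_cap_bytes: int) -> List[List[int]]:
    """Hypothetical sketch: pack gradient tensors (given by byte size) into buckets."""
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    # Assumption: backward produces gradients in roughly reverse parameter order.
    for size in reversed(grad_sizes):
        # Close the current bucket if this tensor would push it past the cap.
        # A tensor larger than the cap still gets a bucket of its own.
        if current and current_bytes + size > bucket_cap_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(size)
        current_bytes += size
    if current:
        buckets.append(current)
    return buckets


print(greedy_buckets([4, 4, 4, 4, 4], 8))  # [[4, 4], [4, 4], [4]]
```

When tensors are small relative to the cap, the bucket count is close to `ceil(total_bytes / cap)`; tensors of uneven size can push it higher.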

Signature: def optimal_bucket_count(grad_bytes: int, latency_per_call_us: float, bandwidth_bytes_per_us: float) -> int

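Model: total time = latency * num_buckets + grad_bytes / bandwidth, where each AllReduce call pays a fixed latency_per_call_us and the gradients move at bandwidth_bytes_per_us. Taken literally, this two-term formula is smallest at a single bucket; it omits the overlap penalty of large buckets. Fewer, larger buckets mean fewer calls, but communication starts later and more of it is left exposed after backward compute finishes. To balance per-call latency against that exposed time, use the heuristic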

bucket_size = sqrt(grad_bytes * bandwidth_bytes_per_us * latency_per_call_us)

then num_buckets = ceil(grad_bytes / bucket_size).
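One way to see where the square root comes from, as a sketch under an assumption: add a term G/(N·B) for the AllReduce of the final bucket, which cannot overlap with backward compute. This extra term is our assumption, not part of the stated model. Here G = grad_bytes, L = latency_per_call_us, B = bandwidth_bytes_per_us and N = num_buckets are shorthand introduced only for this derivation.

```latex
\begin{aligned}
T(N) &= N L + \frac{G}{B} + \frac{G}{N B} \\
\frac{dT}{dN} &= L - \frac{G}{N^{2} B} = 0
  \quad\Longrightarrow\quad N^{*} = \sqrt{\frac{G}{B L}} \\
S^{*} &= \frac{G}{N^{*}} = \sqrt{G B L}
\end{aligned}
```

The G/B term does not depend on N, so it does not move the optimum, and the second derivative 2G/(N^3 B) is positive, so N* is a minimum. Since grad_bytes / bucket_size = G / sqrt(GBL) = N*, the heuristic's ceil(grad_bytes / bucket_size) is simply ceil(N*), the continuous optimum rounded up.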

Example:

  • grad_bytes=10_000_000, latency=10us, bandwidth=1000 B/us → bucket_size = sqrt(1e7 * 1000 * 10) = sqrt(1e11) ≈ 316,228 bytes → num_buckets = ceil(1e7 / 316,228) = ceil(31.6) = 32 buckets
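A minimal sketch of the requested function, assuming all three inputs are positive (the statement does not say what to return otherwise, so this version raises `ValueError`). The helpers `step_time_us` and `brute_force_bucket_count` are hypothetical: they use the same assumed three-term cost model as the derivation (per-call latency, full transfer, plus an exposed final bucket) to cross-check the heuristic on the example by exhaustive search.

```python
import math


def optimal_bucket_count(grad_bytes: int, latency_per_call_us: float,
                         bandwidth_bytes_per_us: float) -> int:
    """Bucket count from the square-root heuristic in the problem statement."""
    # Assumption: inputs are positive; the statement does not cover other cases.
    if grad_bytes <= 0 or latency_per_call_us <= 0 or bandwidth_bytes_per_us <= 0:
        raise ValueError("grad_bytes, latency and bandwidth must be positive")
    # bucket_size has units of bytes: sqrt(bytes * bytes/us * us).
    bucket_size = math.sqrt(grad_bytes * bandwidth_bytes_per_us * latency_per_call_us)
    # Round up so the final, partially filled bucket is counted.
    return math.ceil(grad_bytes / bucket_size)


def step_time_us(n: int, grad_bytes: float, latency: float, bandwidth: float) -> float:
    """Hypothetical cost model: n launches + full transfer + exposed last bucket."""
    return n * latency + grad_bytes / bandwidth + (grad_bytes / n) / bandwidth


def brute_force_bucket_count(grad_bytes: float, latency: float, bandwidth: float,
                             max_buckets: int = 10_000) -> int:
    """Pick the n in [1, max_buckets] with the smallest modeled step time."""
    return min(range(1, max_buckets + 1),
               key=lambda n: step_time_us(n, grad_bytes, latency, bandwidth))


print(optimal_bucket_count(10_000_000, 10.0, 1000.0))      # 32
print(brute_force_bucket_count(10_000_000, 10.0, 1000.0))  # 32
```

For this example the heuristic and the exhaustive search agree, but that need not hold for every input: the integer minimizer of a convex cost is either the floor or the ceiling of N*, and always rounding up can overshoot by one bucket.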
