DDP overlaps gradient AllReduce with backward compute by bucketing gradients. Compute the optimal number of buckets that minimizes total step time.
Signature: def optimal_bucket_count(grad_bytes: int, latency_per_call_us: float, bandwidth_bytes_per_us: float) -> int
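Model: total time = latency_per_call_us * num_buckets + grad_bytes / bandwidth_bytes_per_us. More buckets add per-call latency. Larger buckets mean fewer calls, but each AllReduce waits for its whole bucket to fill, so less communication overlaps with backward compute. Use the heuristic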
bucket_size = sqrt(grad_bytes * bandwidth * latency_per_call_us)
then num_buckets = ceil(grad_bytes / bucket_size).
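One way to see where the square root comes from is sketched below. It assumes an extra cost term that the model above does not state: with bucket size b there are about grad_bytes / b calls, and the last bucket's transfer, b / bandwidth, is left exposed after backward finishes. Write G = grad_bytes, α = latency_per_call_us and β = bandwidth_bytes_per_us:

```latex
\begin{aligned}
T(b) &\approx \alpha \frac{G}{b} + \frac{G}{\beta} + \frac{b}{\beta}, \\
\frac{dT}{db} &= -\frac{\alpha G}{b^{2}} + \frac{1}{\beta} = 0
\;\Longrightarrow\; b^{*} = \sqrt{G\,\beta\,\alpha},
\qquad n^{*} = \left\lceil \frac{G}{b^{*}} \right\rceil .
\end{aligned}
```

The G/β term does not depend on b, so it drops out of the derivative. The second derivative, 2αG/b³, is positive, so b* is a minimum. Without the assumed exposed-tail term, the cost would fall monotonically as b grows and would always favor a single bucket.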
Example:
grad_bytes=10_000_000, latency=10us, bandwidth=1000 B/us → bucket_size = sqrt(1e7 * 1000 * 10) = sqrt(1e11) ≈ 316,228 bytes → num_buckets = ceil(1e7 / 316,228) = ceil(31.62) = 32
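A minimal Python sketch of the requested function, applying the heuristic exactly as stated. Rejecting non-positive inputs with ValueError is an assumption added here, because the problem does not say what to return when latency or bandwidth is zero (the heuristic would divide by zero).

```python
import math


def optimal_bucket_count(grad_bytes: int, latency_per_call_us: float,
                         bandwidth_bytes_per_us: float) -> int:
    """Return the bucket count given by the square-root heuristic.

    Assumption (not stated in the problem): all inputs must be positive;
    anything else raises ValueError instead of dividing by zero.
    """
    if grad_bytes <= 0 or latency_per_call_us <= 0 or bandwidth_bytes_per_us <= 0:
        raise ValueError("grad_bytes, latency and bandwidth must all be positive")

    # bytes * (bytes/us) * us = bytes^2, so the square root is a size in bytes.
    bucket_size = math.sqrt(grad_bytes * bandwidth_bytes_per_us * latency_per_call_us)

    # Round up so every gradient byte lands in some bucket; a bucket_size
    # larger than grad_bytes still yields exactly one bucket.
    return math.ceil(grad_bytes / bucket_size)


if __name__ == "__main__":
    print(optimal_bucket_count(10_000_000, 10.0, 1000.0))  # 32
```

For grad_bytes=10_000_000, latency 10 us and bandwidth 1000 B/us, the buckets come out at about 316,228 bytes and the function returns 32.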