TorchedUp
Difficulty: Medium

Optimal Shard Count per Node

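Given a model that may not fit on a single GPU, compute the minimum number of shards per node (one shard per GPU) needed to fit it in VRAM, clamped to the number of GPUs the node actually has. The result is never less than 1, so a model that fits on one GPU yields 1.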

Signature: def optimal_shards_per_node(model_bytes: int, gpus_per_node: int, gpu_vram_bytes: int) -> int

Formula: max(1, min(gpus_per_node, ceil(model_bytes / gpu_vram_bytes))).

Example:

  • 70B model in fp16 ≈ 140 GB, 8x A100-80GB → ceil(140/80) = ceil(1.75) = 2, and min(8, 2) = 2 → 2 shards.
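
The formula and the example above can be sketched in a few lines of Python. This is a minimal sketch, not a reference solution: it assumes gpu_vram_bytes is positive, uses integer ceiling division in place of math.ceil to avoid float rounding on large byte counts, and the GB constant (decimal gigabytes) is an assumption introduced only for the illustration.

```python
def optimal_shards_per_node(model_bytes: int, gpus_per_node: int, gpu_vram_bytes: int) -> int:
    # Ceiling division on integers: -(-a // b) == ceil(a / b) for b > 0.
    shards_needed = -(-model_bytes // gpu_vram_bytes)
    # Clamp to the GPUs available on the node, and never return less than 1.
    return max(1, min(gpus_per_node, shards_needed))


GB = 10**9  # assumption: decimal gigabytes, for illustration only

# 70B parameters in fp16 is roughly 140 GB; each GPU has 80 GB of VRAM.
print(optimal_shards_per_node(140 * GB, 8, 80 * GB))  # 2
```

When the unclamped shard count exceeds gpus_per_node, the function returns gpus_per_node, which means the model would not fit on that node alone; how to handle that case is outside what the formula covers.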
