Given a model and a fleet, return a basic recommendation for how to split the work: tensor parallel (TP), data parallel (DP), and ZeRO stage.
Signature: def recommend_parallelism(n_params: int, n_gpus: int, gpu_vram_gb: int, training: bool) -> dict
Return a dict with keys 'tp', 'dp', 'zero_stage' (the latter is 0 for inference).
Heuristic:
Inference (training=False):
- Pick the smallest tp in {1, 2, 4, 8} that satisfies weight_bytes / tp <= gpu_vram_gb * 1e9, where weight_bytes is roughly 2 * n_params (fp16 weights). If none fits, pick tp = 8.
- tp must also divide n_gpus; if it does not, fall back to tp = min(8, n_gpus).
- dp = n_gpus // tp.
- zero_stage = 0.

Training (training=True):
- tp = min(8, n_gpus); dp = n_gpus // tp.
- Estimate the per-GPU training state as 16 * n_params / tp bytes (fp16 weights, fp16 gradients, and fp32 Adam states at 16 bytes per parameter). If that fits in gpu_vram_gb * 1e9 after leaving 8 GB of headroom for activations, use zero_stage = 2; otherwise use zero_stage = 3.

Return the dict.
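The heuristic above can be sketched directly in code. This is one possible implementation under the stated assumptions: fp16 weights (2 bytes per parameter) for the inference memory check, and 16 bytes per parameter of training state.

```python
def recommend_parallelism(n_params: int, n_gpus: int, gpu_vram_gb: int, training: bool) -> dict:
    """Recommend a (tp, dp, zero_stage) split using the simple heuristic above."""
    vram_bytes = gpu_vram_gb * 1e9

    if training:
        tp = min(8, n_gpus)
        dp = n_gpus // tp
        # 16 bytes/param: fp16 weights + fp16 grads + fp32 Adam master weights and moments
        state_bytes = 16 * n_params / tp
        headroom = 8e9  # reserve 8 GB for activations
        zero_stage = 2 if state_bytes <= vram_bytes - headroom else 3
        return {"tp": tp, "dp": dp, "zero_stage": zero_stage}

    # Inference: assume fp16 weights, i.e. 2 bytes per parameter
    weight_bytes = 2 * n_params
    tp = next((t for t in (1, 2, 4, 8) if weight_bytes / t <= vram_bytes), 8)
    if n_gpus % tp != 0:
        tp = min(8, n_gpus)
    dp = n_gpus // tp
    return {"tp": tp, "dp": dp, "zero_stage": 0}
```

For example, serving a 70B-parameter model on 16x80GB GPUs: 140 GB of fp16 weights needs tp = 2 to fit per-GPU, leaving dp = 8 and zero_stage = 0.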