Tensor parallelism (Megatron-style) shards each weight matrix across tp_size ranks. Compute the weight memory per rank in bytes.
Signature: def tp_memory_per_rank_bytes(n_params: int, tp_size: int, dtype_bytes: int = 2) -> int
Formula:
bytes = (n_params * dtype_bytes) // tp_size
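We approximate that the entire parameter count shards evenly. In reality, the split is not perfectly even: layer norm parameters are usually replicated on every rank rather than sharded, and embeddings are split along the vocabulary dimension. The matmul-dominated weights (>99% of params for decent-sized models) do shard cleanly, so the approximation is close.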
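To get a feel for how small the error from the even-split assumption is, here is a minimal sketch that counts replicated parameters in full on every rank and shards only the rest. The function name tp_memory_per_rank_bytes_split and the split into n_sharded / n_replicated are hypothetical, introduced only for illustration, and the 1M replicated-parameter figure is a made-up round number, not a measured count for any real model.

```python
def tp_memory_per_rank_bytes(n_params: int, tp_size: int, dtype_bytes: int = 2) -> int:
    # Even-split approximation from the problem statement
    return (n_params * dtype_bytes) // tp_size


def tp_memory_per_rank_bytes_split(
    n_sharded: int, n_replicated: int, tp_size: int, dtype_bytes: int = 2
) -> int:
    # Hypothetical refinement: sharded params are divided across ranks,
    # replicated params (e.g. layer norm weights) are stored in full on every rank
    if tp_size <= 0:
        raise ValueError("tp_size must be positive")
    sharded_bytes = (n_sharded * dtype_bytes) // tp_size
    replicated_bytes = n_replicated * dtype_bytes
    return sharded_bytes + replicated_bytes


# Illustrative, made-up split: 7B params total, 1M of them assumed replicated
n_total = 7_000_000_000
n_repl = 1_000_000
exact = tp_memory_per_rank_bytes_split(n_total - n_repl, n_repl, tp_size=4)
approx = tp_memory_per_rank_bytes(n_total, tp_size=4)
print(exact, approx, (exact - approx) / approx)
```

With these assumed numbers the refined estimate is 3,501,500,000 bytes versus 3,500,000,000 for the even split, a relative difference of about 0.04%.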
Example: A 7B model in fp16 (2 bytes per parameter) with TP=4 -> 7e9 * 2 / 4 = 3.5e9 bytes (3.5 GB) per rank.
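The fp16 example can be reproduced and extended to other TP sizes with a short sketch of the formula. Note that the parameter count is written as an integer literal: 7e9 is a float in Python, which would not match the int signature. "GB" in the printout means decimal gigabytes (1e9 bytes).

```python
def tp_memory_per_rank_bytes(n_params: int, tp_size: int, dtype_bytes: int = 2) -> int:
    # Even-split approximation: total weight bytes divided across tp_size ranks
    return (n_params * dtype_bytes) // tp_size


n_params = 7_000_000_000  # integer literal; 7e9 would be a float
for tp in (1, 2, 4, 8):
    per_rank = tp_memory_per_rank_bytes(n_params, tp)
    print(f"TP={tp}: {per_rank} bytes ({per_rank / 1e9:.2f} GB)")
# TP=1: 14000000000 bytes (14.00 GB)
# TP=2: 7000000000 bytes (7.00 GB)
# TP=4: 3500000000 bytes (3.50 GB)
# TP=8: 1750000000 bytes (1.75 GB)
```

The TP=4 line matches the 3.5e9 bytes above. Passing dtype_bytes=4 (fp32) doubles every figure.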
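def tp_memory_per_rank_bytes(n_params: int, tp_size: int, dtype_bytes: int = 2) -> int:
    # Total weight bytes, split evenly across tp_size ranks (floor division keeps an int)
    return (n_params * dtype_bytes) // tp_size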