TorchedUp

118. Tensor-Parallel Memory per Rank

Medium

Tensor parallelism (Megatron-style) shards each weight matrix across tp_size ranks. Compute the weight memory per rank in bytes.

Signature: def tp_memory_per_rank_bytes(n_params: int, tp_size: int, dtype_bytes: int = 2) -> int

Formula:

bytes = (n_params * dtype_bytes) // tp_size
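Here is a minimal Python sketch of the formula above. It relies on Python's arbitrary-precision integers, so the product `n_params * dtype_bytes` cannot overflow. The input checks are an addition of this sketch and are not part of the stated spec.

```python
def tp_memory_per_rank_bytes(n_params: int, tp_size: int, dtype_bytes: int = 2) -> int:
    """Weight bytes held by one tensor-parallel rank, assuming every parameter shards evenly."""
    # Input validation is an assumption of this sketch, not part of the problem spec.
    if tp_size < 1:
        raise ValueError("tp_size must be >= 1")
    if n_params < 0 or dtype_bytes < 1:
        raise ValueError("n_params must be >= 0 and dtype_bytes must be >= 1")
    # Total weight bytes across all ranks, split into tp_size equal shards.
    # Floor division matches the stated formula; any remainder is dropped.
    return (n_params * dtype_bytes) // tp_size


print(tp_memory_per_rank_bytes(7_000_000_000, 4))  # 3500000000
```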

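We assume that the entire parameter count shards evenly. In practice, LayerNorm parameters are replicated on every rank rather than sharded, and the embedding is split along the vocabulary dimension, so the per-rank count is not exactly N / TP. The matmul-dominated weights, however, shard cleanly, and they make up more than 99% of the parameters in reasonably large models.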
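To show how small the error from this approximation is, the hypothetical helper below (`tp_memory_per_rank_split_bytes`, not part of the problem) separates parameters that shard from parameters that are replicated on every rank. The caller supplies the split. The 1M replicated-parameter figure in the usage line is an illustrative assumption, not a measured value for any real model.

```python
def tp_memory_per_rank_split_bytes(
    sharded_params: int,
    replicated_params: int,
    tp_size: int,
    dtype_bytes: int = 2,
) -> int:
    """Hypothetical refinement: sharded weights are split, replicated weights are copied to every rank."""
    if tp_size < 1:
        raise ValueError("tp_size must be >= 1")
    # Matmul weights (and the vocab-split embedding) are divided across ranks.
    sharded_bytes = (sharded_params * dtype_bytes) // tp_size
    # Parameters such as LayerNorm scales/biases are assumed to be stored in full on each rank.
    replicated_bytes = replicated_params * dtype_bytes
    return sharded_bytes + replicated_bytes


# Hypothetical split: 1M replicated parameters out of 7B total, fp16, TP=4.
est = tp_memory_per_rank_split_bytes(6_999_000_000, 1_000_000, tp_size=4)
print(est)  # 3501500000
```

The result is 1.5 MB (about 0.04%) above the even-shard estimate of 3.5e9 bytes. A gap that small is why the simpler formula is usually good enough.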

Example: A 7B model in fp16 with TP=4 -> 7e9 * 2 / 4 = 3.5e9 bytes per rank.
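The sketch below sweeps the same 7B fp16 example over a few TP sizes and also reports binary gibibytes (2**30 bytes), since memory is often quoted that way. These figures cover weights only. Gradients, optimizer state, and activations are outside the scope of this formula.

```python
N_PARAMS = 7_000_000_000  # the 7B example from the text
DTYPE_BYTES = 2           # fp16

# Per-rank weight bytes for several tensor-parallel sizes.
per_rank = {tp: (N_PARAMS * DTYPE_BYTES) // tp for tp in (1, 2, 4, 8)}

for tp, nbytes in per_rank.items():
    print(f"TP={tp}: {nbytes:,} bytes ≈ {nbytes / 2**30:.2f} GiB")
```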

Math

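$$M_{\text{rank}} = \frac{N \cdot b}{\mathrm{TP}}$$

where N = n_params, b = dtype_bytes, and TP = tp_size (the function returns the floor of this quotient).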



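import numpy as np


def tp_memory_per_rank_bytes(n_params: int, tp_size: int, dtype_bytes: int = 2) -> int:
    # Starter stub: return the weight bytes held by one tensor-parallel rank.
    pass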
