TorchedUp

263. Rotary Position Embedding (PyTorch)

Medium

Implement RoPE in PyTorch. RoPE rotates query/key vectors by position-dependent angles instead of adding a positional embedding — the rotation directly modifies the QK^T dot product to encode relative position.

Signature: def rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor

  • x: (..., seq_len, d) — query or key vectors. d is even.
  • positions: (..., seq_len) — integer (or float) position indices, broadcastable with the leading batch dims of x.
  • Returns: same shape as x.

Pairing convention: the LLaMA / GPT-NeoX layout splits x into halves: the first d/2 and the last d/2 form the rotated pairs. (The original RoPE paper uses interleaved pairs (x_0, x_1), (x_2, x_3), .... The two are equivalent up to a permutation.)
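The "equivalent up to a permutation" claim can be checked numerically for a small d. The sketch below (arbitrary example values, not site code) applies both pairings to a d = 4 vector and compares the results under the even-indices-first permutation:

```python
import torch

# Half-split vs interleaved pairing on a d = 4 vector at position m = 2
# (a self-contained check of the equivalence claim; values are arbitrary).
d, m = 4, 2.0
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
i = torch.arange(d // 2, dtype=torch.float32)
theta = 1.0 / 10000 ** (2 * i / d)
cos, sin = torch.cos(m * theta), torch.sin(m * theta)

# Permute the input so even indices come first, then odd indices.
perm = torch.tensor([0, 2, 1, 3])
y = x[perm]

# Half-split (LLaMA / GPT-NeoX): pairs are (y[0], y[2]) and (y[1], y[3]).
a, b = y[:2], y[2:]
half_split = torch.cat([a * cos - b * sin, a * sin + b * cos])

# Interleaved (original paper): pairs are (x[0], x[1]) and (x[2], x[3]).
a, b = x[0::2], x[1::2]
interleaved = torch.stack([a * cos - b * sin, a * sin + b * cos], dim=-1).flatten()

# Half-split on the permuted input equals the permuted interleaved output.
print(torch.allclose(half_split, interleaved[perm]))  # True
```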

theta_i = 1 / 10000 ** (2i / d),    i = 0, ..., d/2 - 1
[x1', x2'] = [x1 * cos(m*theta) - x2 * sin(m*theta),
              x1 * sin(m*theta) + x2 * cos(m*theta)]
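For a single pair at position m, the formula above is just a 2D rotation. A quick numeric sanity check (arbitrary example values):

```python
import torch

# Rotate one (x1, x2) pair at position m = 3 with frequency theta = 0.5,
# applying the formula above directly (example values, not library code).
m, theta = 3.0, 0.5
x1, x2 = torch.tensor(1.0), torch.tensor(2.0)

angle = torch.tensor(m * theta)
x1_new = x1 * torch.cos(angle) - x2 * torch.sin(angle)
x2_new = x1 * torch.sin(angle) + x2 * torch.cos(angle)

# A rotation preserves the pair's squared norm.
print(torch.isclose(x1_new**2 + x2_new**2, x1**2 + x2**2))  # tensor(True)
```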

PyTorch idioms vs NumPy:

  • torch.arange(half_d, dtype=torch.float32) — the explicit dtype keeps the exponent computation in floating point instead of the int64 arange default.
  • positions.unsqueeze(-1) * theta is the broadcast-friendly form of NumPy's positions[..., None] * theta.
  • torch.cat([..., ...], dim=-1) instead of np.concatenate([...], axis=-1).
  • x[..., :half_d] slicing works the same as NumPy.
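Putting these idioms together, one possible half-split implementation looks like the sketch below (matching the signature in the prompt; this is not the site's hidden reference solution):

```python
import torch

def rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """RoPE with the half-split (LLaMA / GPT-NeoX) pairing. Assumes d is even."""
    d = x.shape[-1]
    half_d = d // 2
    # theta_i = 1 / 10000 ** (2i / d), kept in float32 to avoid int64 math
    i = torch.arange(half_d, dtype=torch.float32)
    theta = 1.0 / 10000 ** (2 * i / d)
    # (..., seq_len, 1) * (half_d,) broadcasts to (..., seq_len, half_d)
    angles = positions.unsqueeze(-1).to(torch.float32) * theta
    cos, sin = torch.cos(angles), torch.sin(angles)
    x1, x2 = x[..., :half_d], x[..., half_d:]  # the (x1, x2) pairs
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Shape is preserved, and position 0 is the identity rotation.
x = torch.randn(2, 5, 8)
out = rope(x, torch.arange(5))
print(out.shape)  # torch.Size([2, 5, 8])

# The QK^T dot product depends only on the relative offset m - n:
q, k = torch.randn(8), torch.randn(8)
def dot_at(m: int, n: int) -> torch.Tensor:
    return rope(q.unsqueeze(0), torch.tensor([m])) @ rope(k.unsqueeze(0), torch.tensor([n])).T
print(torch.allclose(dot_at(3, 1), dot_at(13, 11), atol=1e-4))  # True (offset 2 in both)
```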

Math

θ_i = 1 / 10000^(2i/d),    i = 0, 1, ..., d/2 − 1

( x1' )   ( cos(mθ)  −sin(mθ) ) ( x1 )
( x2' ) = ( sin(mθ)   cos(mθ) ) ( x2 )

Related problems

  • Rotary Position Embedding (RoPE) (Medium, NumPy)

PyTorch

import torch


def rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    pass

🔒 Premium problem