Implement RoPE in PyTorch. RoPE rotates query/key vectors by position-dependent angles instead of adding a positional embedding — the rotation directly modifies the QK^T dot product to encode relative position.
Signature: def rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor
x: (..., seq_len, d). Query or key vectors; d is even.
positions: (..., seq_len). Integer (or float) position indices, broadcastable with the leading batch dims of x.
Pairing convention: the LLaMA / GPT-NeoX layout splits x into halves, so the first d/2 and the last d/2 entries form the rotated pairs. (The original RoPE paper uses interleaved pairs (x_0, x_1), (x_2, x_3), ...; the two layouts are equivalent up to a fixed permutation of dimensions.) A sketch of the halves layout follows.
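For concreteness, a minimal sketch of the halves pairing. The helper name rotate_half is ours, not part of the problem; the interleaved layout would slice x[..., 0::2] / x[..., 1::2] instead.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Halves (LLaMA / GPT-NeoX) layout: x[..., i] pairs with x[..., i + d/2].
    # Maps (x1, x2) to (-x2, x1), so that x * cos + rotate_half(x) * sin
    # applies the rotation defined below.
    half = x.shape[-1] // 2
    return torch.cat([-x[..., half:], x[..., :half]], dim=-1)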
theta_i = 1 / 10000 ** (2i / d),  i = 0, ..., d/2 - 1

For a pair (x1, x2) at position m, the rotation by angle m*theta is:

[x1', x2'] = [x1 * cos(m*theta) - x2 * sin(m*theta),
              x1 * sin(m*theta) + x2 * cos(m*theta)]
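To see the relative-position claim from the first paragraph in action, here is a small 2-D check (all values illustrative, rot2d is a hypothetical helper): rotating q by m*theta and k by n*theta leaves the dot product depending only on the offset m - n.

import torch

def rot2d(v: torch.Tensor, angle: float) -> torch.Tensor:
    # Apply the 2-D rotation above to a single pair (x1, x2).
    c = torch.cos(torch.tensor(angle))
    s = torch.sin(torch.tensor(angle))
    return torch.stack([v[0] * c - v[1] * s, v[0] * s + v[1] * c])

q, k = torch.tensor([1.0, 2.0]), torch.tensor([0.5, -1.0])
theta, m, n = 0.3, 7, 4
lhs = rot2d(q, m * theta) @ rot2d(k, n * theta)   # rotate both, then dot
rhs = rot2d(q, (m - n) * theta) @ k               # only the offset m - n matters
assert torch.allclose(lhs, rhs)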
PyTorch idioms vs NumPy:
- torch.arange(half_d, dtype=torch.float32): the explicit dtype avoids the int64 default, which breaks the float division.
- positions.unsqueeze(-1) * theta is the broadcast-friendly form of NumPy's positions[..., None] * theta.
- torch.cat([..., ...], dim=-1) instead of np.concatenate([...], axis=-1).
- x[..., :half_d] slicing works the same as in NumPy.
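Putting these idioms together, a minimal sketch of a halves-layout solution (an illustration under the conventions above, not the graded reference solution; float32 for the angle table is an assumption):

import torch

def rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1]
    half_d = d // 2
    # theta_i = 1 / 10000 ** (2i / d) for i = 0, ..., d/2 - 1
    i = torch.arange(half_d, dtype=torch.float32)
    theta = 1.0 / 10000.0 ** (2.0 * i / d)
    # angles: (..., seq_len, d/2), one angle per (position, frequency) pair
    angles = positions.unsqueeze(-1).to(torch.float32) * theta
    cos, sin = torch.cos(angles), torch.sin(angles)
    # Halves layout: x1[..., i] pairs with x2[..., i]
    x1, x2 = x[..., :half_d], x[..., half_d:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

Computing the cos/sin tables once and broadcasting them against the split halves keeps the whole transform a handful of vectorized ops, with no per-position loop.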
import torch

def rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # TODO: rotate each (x1, x2) pair by its position-dependent angle.
    pass
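A quick smoke test against the sketch above (shapes illustrative): position 0 must act as the identity, since cos(0) = 1 and sin(0) = 0.

x = torch.randn(2, 5, 8)                   # (batch, seq_len, d) with d even
positions = torch.arange(5)                # 0, 1, ..., 4 for every batch element
out = rope(x, positions)
assert out.shape == x.shape
assert torch.allclose(out[:, 0], x[:, 0])  # position 0: rotation by angle 0 is the identity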