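Mistral uses sliding window attention: each token attends only to the W most recent tokens (including itself). This makes the attention cost O(N·W) instead of O(N²), enabling efficient inference over long contexts. Combined with a rolling KV cache that stores only the last W keys and values, it can process arbitrarily long sequences at a fixed memory cost. Tokens older than W positions still have an influence, but only indirectly: each stacked layer lets information propagate up to W further positions.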
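The rolling KV cache can be sketched as a circular buffer of W slots, where token t overwrites slot t mod W. This is a minimal illustration under simplifying assumptions: a single head, one query per decode step, and positional information (for example rotary embeddings) already applied to the cached keys. The class `RollingKVCache` and its methods are hypothetical names, not Mistral's actual code. Each decode step reads at most W cached entries, which is where the O(W) per-token, O(N·W) total, cost comes from.

```python
import numpy as np


class RollingKVCache:
    """Hypothetical fixed-size cache holding only the last W keys/values."""

    def __init__(self, window_size, d_k, d_v):
        self.W = window_size
        self.keys = np.zeros((window_size, d_k))
        self.values = np.zeros((window_size, d_v))
        self.count = 0  # total number of tokens appended so far

    def append(self, k, v):
        # Once the buffer is full, this overwrites the oldest token's slot.
        slot = self.count % self.W
        self.keys[slot] = k
        self.values[slot] = v
        self.count += 1

    def attend(self, q):
        # Only the first min(count, W) slots hold real tokens.
        n = min(self.count, self.W)
        K, V = self.keys[:n], self.values[:n]
        scores = K @ q / np.sqrt(q.shape[-1])  # scaled dot products, shape (n,)
        weights = np.exp(scores - scores.max())  # numerically stable softmax
        weights /= weights.sum()
        return weights @ V
```

Because softmax attention is invariant to the order of the key/value pairs, the buffer never needs to be rotated back into chronological order; memory stays at W entries no matter how many tokens are decoded.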
Implement causal sliding window attention: token i can attend to positions max(0, i-W+1) through i inclusive (no future tokens, and no tokens older than W-1 steps back). Within that window, run the usual scaled dot-product softmax attention.
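One possible NumPy reference, assuming `window_size >= 1`, builds a boolean mask from the two conditions above (causal, and no older than W-1 steps back) and applies an ordinary masked softmax. It materializes the full N×N score matrix, so it costs O(N²) and is best treated as a correctness baseline rather than the O(N·W) kernel.

```python
import numpy as np


def sliding_window_attention(Q, K, V, window_size):
    """Causal sliding-window attention via an (N, N) boolean mask.

    Query i attends to keys max(0, i - window_size + 1) .. i inclusive.
    Assumes window_size >= 1. Builds the full mask, so this is O(N^2).
    """
    N, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)              # (N, N) scaled dot products
    i = np.arange(N)[:, None]                     # query positions (column)
    j = np.arange(N)[None, :]                     # key positions (row)
    allowed = (j <= i) & (j > i - window_size)    # causal AND inside the window
    scores = np.where(allowed, scores, -np.inf)   # disallowed keys get zero weight
    scores -= scores.max(axis=1, keepdims=True)   # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V
```

The diagonal (j = i) is always allowed, so every row keeps at least one finite score and the softmax never produces NaNs. With `window_size=1` each token attends only to itself, so the output equals V.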
Signature: def sliding_window_attention(Q, K, V, window_size)
Q, K, V: arrays of shape (N, d_k)
window_size: W
Hint: when W ≥ N this reduces to standard causal attention.
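The hint can be checked on the mask alone: when W ≥ N, every key index satisfies j ≥ 0 > i − W (since i ≤ N − 1), so the window condition is always true and only the causal condition j ≤ i remains. The helper `window_mask` below is a hypothetical name used for illustration.

```python
import numpy as np


def window_mask(N, window_size):
    """Hypothetical helper: True where query i may attend to key j."""
    i = np.arange(N)[:, None]
    j = np.arange(N)[None, :]
    return (j <= i) & (j > i - window_size)


N = 5
causal = np.tril(np.ones((N, N), dtype=bool))         # standard causal mask
print(np.array_equal(window_mask(N, N), causal))      # W = N  → True
print(np.array_equal(window_mask(N, N + 3), causal))  # W > N  → True
print(window_mask(N, 2).astype(int))                  # W = 2: diagonal plus one sub-diagonal
```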
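import numpy as np

def sliding_window_attention(Q, K, V, window_size):
    # Return an (N, d_k) array whose row i attends over keys max(0, i - window_size + 1) .. i.
    pass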