Mistral uses sliding window attention: each token attends only to the W most recent tokens (including itself). This reduces attention cost from O(N²) to O(N·W), enabling efficient inference over long contexts. Combined with a rolling KV cache, memory stays fixed regardless of sequence length.
Implement causal sliding window attention: token i can attend to positions max(0, i-W+1) through i inclusive.
For each query i:
    start = max(0, i - W + 1)
    K_local = K[start : i+1]
    V_local = V[start : i+1]
    scores = Q[i] @ K_local.T / sqrt(d_k)
    weights = softmax(scores)
    out[i] = weights @ V_local
Signature: def sliding_window_attention(Q, K, V, window_size)
Q, K, V: (N, d_k) arrays
window_size: W
Hint: when W ≥ N this reduces to standard causal attention.
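A minimal NumPy sketch of the loop above (the per-row max subtraction inside the softmax is an added numerical-stability step, not part of the original pseudocode):

```python
import numpy as np

def sliding_window_attention(Q, K, V, window_size):
    """Causal sliding window attention.

    Query i attends to key/value positions max(0, i - window_size + 1)
    through i inclusive, so each row's softmax runs over at most
    window_size scores.
    """
    N, d_k = Q.shape
    out = np.zeros_like(V)
    for i in range(N):
        start = max(0, i - window_size + 1)
        K_local = K[start : i + 1]                 # (w, d_k), w <= window_size
        V_local = V[start : i + 1]                 # (w, d_k)
        scores = Q[i] @ K_local.T / np.sqrt(d_k)   # (w,)
        scores = scores - scores.max()             # stability: shift before exp
        weights = np.exp(scores)
        weights = weights / weights.sum()          # softmax over the local window
        out[i] = weights @ V_local
    return out
```

Sanity checks worth running: token 0 can only attend to itself, so `out[0]` should equal `V[0]` exactly; and per the hint, any `window_size >= N` should give the same result as standard causal attention.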