In prefix tuning, we prepend learned key/value vectors to the attention's K and V (but not Q). The base model stays frozen; only the prefix tensors are trained.
Signature: def prefix_tuning_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, K_prefix: np.ndarray, V_prefix: np.ndarray) -> np.ndarray
Shapes (single head):
- Q: (T_q, d_k)
- K, V: (T_kv, d_k)
- K_prefix, V_prefix: (P, d_k), where P is the number of learned prefix slots
- Returns: scaled dot-product attention of Q against the augmented keys concat([K_prefix, K]) and values concat([V_prefix, V]). Output shape: (T_q, d_k).
Math
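With the prefix concatenated onto the keys and values, the output is standard scaled dot-product attention over the augmented sequences:

$$
K' = \begin{bmatrix} K_{\text{prefix}} \\ K \end{bmatrix}, \quad
V' = \begin{bmatrix} V_{\text{prefix}} \\ V \end{bmatrix}, \quad
\text{Attention}(Q, K', V') = \operatorname{softmax}\!\left(\frac{Q K'^{\top}}{\sqrt{d_k}}\right) V'
$$

where the softmax is taken row-wise over the P + T_kv augmented key positions.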
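A minimal NumPy sketch of the signature above, assuming a single head and a row-wise softmax over the augmented key axis:

```python
import numpy as np

def prefix_tuning_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                            K_prefix: np.ndarray, V_prefix: np.ndarray) -> np.ndarray:
    # Prepend the learned prefix slots to the (frozen) keys and values.
    K_aug = np.concatenate([K_prefix, K], axis=0)   # (P + T_kv, d_k)
    V_aug = np.concatenate([V_prefix, V], axis=0)   # (P + T_kv, d_k)

    # Scaled dot-product scores of Q against the augmented keys.
    d_k = Q.shape[-1]
    scores = Q @ K_aug.T / np.sqrt(d_k)             # (T_q, P + T_kv)

    # Numerically stable row-wise softmax over the augmented key positions.
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)

    # Weighted sum of the augmented values.
    return weights @ V_aug                          # (T_q, d_k)
```

For example, with T_q = 2, T_kv = 3, P = 4, and d_k = 8, the returned array has shape (2, 8): each query attends jointly over the 4 learned prefix slots and the 3 real key/value positions.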