Einsum is a one-line DSL that expresses combinations of multiplication, summation, and broadcasting using axis labels. Once you can read it, much of the array pseudocode in papers becomes mechanical to translate.
Implement: def three_einsum_translations(a, b) where a has shape (B, M, K) and b has shape (B, K, N). Return a tuple of three arrays computed via np.einsum:
1. Batched matmul of a and b, shape (B, M, N). Pattern: 'bmk,bkn->bmn'.
2. Sum of a over its last axis, shape (B, M). Pattern: 'bmk->bm'.
3. Sum of b over its middle axis, shape (B, N). Pattern: 'bkn->bn'.
Constraint: All three results must come from np.einsum: no @, no np.matmul, no np.sum.
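Reading einsum: each letter labels one axis of an operand. A letter that appears after -> is kept in the output; a letter that appears only before -> is summed over.
So 'bmk,bkn->bmn' says: array a has axes (b, m, k) and array b has axes (b, k, n); the shared label k is summed, the label b is kept (batched), and m and n are kept. That is exactly batched matmul.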
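To make that reading concrete, the sketch below spells out the same computation by hand: align the axes as (b, m, k, n), multiply elementwise, then sum over k. The array sizes and the random inputs are arbitrary choices for illustration, not part of the problem, and np.sum appears here only as a cross-check, not as a solution.

```python
import numpy as np

rng = np.random.default_rng(0)
B, M, K, N = 2, 3, 4, 5  # arbitrary small sizes, chosen only for this illustration
a = rng.standard_normal((B, M, K))
b = rng.standard_normal((B, K, N))

# einsum form: k appears in both inputs but not in the output, so it is summed.
via_einsum = np.einsum('bmk,bkn->bmn', a, b)

# The same rule written out: give both operands axes (b, m, k, n), multiply, sum over k.
aligned = a[:, :, :, None] * b[:, None, :, :]  # shape (B, M, K, N)
via_broadcast = aligned.sum(axis=2)            # shape (B, M, N)

shapes_match = via_einsum.shape == (B, M, N)
values_match = np.allclose(via_einsum, via_broadcast)
```

Both flags should come out True, which is the "shared k is summed, everything else is kept" rule stated in code.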
import numpy as np

def three_einsum_translations(a, b):
    # a: (B, M, K), b: (B, K, N); return the three np.einsum results as a tuple
    pass
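One way the stub might be filled in, using only the three patterns given in the problem statement, is sketched below. This is not the site's reference solution; the test inputs built with np.arange are an assumption made for the sanity check.

```python
import numpy as np

def three_einsum_translations(a, b):
    """a: (B, M, K), b: (B, K, N) -> (batched matmul, a summed over K, b summed over K)."""
    batched = np.einsum('bmk,bkn->bmn', a, b)  # k is summed; b, m, n are kept
    a_sum = np.einsum('bmk->bm', a)            # k is absent from the output, so it is summed
    b_sum = np.einsum('bkn->bn', b)            # same rule; k is b's middle axis
    return batched, a_sum, b_sum

# Illustrative inputs (hypothetical, not from the problem's hidden tests).
a = np.arange(24, dtype=float).reshape(2, 3, 4)
b = np.arange(40, dtype=float).reshape(2, 4, 5)
batched, a_sum, b_sum = three_einsum_translations(a, b)
```

The function body itself respects the constraint; comparing its outputs against np.matmul(a, b), a.sum(axis=-1), and b.sum(axis=1) outside the function is a reasonable way to check it.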