TorchedUp

224. NumPy Einsum Intro

Medium

Einsum is a one-line DSL that expresses many combinations of multiplication, summation, and broadcasting using named axis labels. Once you can read it, translating the index-notation formulas in papers becomes largely mechanical.

Implement: def three_einsum_translations(a, b) where a has shape (B, M, K) and b has shape (B, K, N). Return a tuple of three arrays computed via np.einsum:

  1. Batched matmul — shape (B, M, N). Pattern: 'bmk,bkn->bmn'.
  2. Sum a over its last axis — shape (B, M). Pattern: 'bmk->bm'.
  3. Sum b over its middle axis — shape (B, N). Pattern: 'bkn->bn'.

Constraint: All three results must come from np.einsum — no @, no np.matmul, no np.sum.

Reading einsum:

  • Letters on the input side label every axis of every input array.
  • Letters on the output side say "keep these".
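  • Input letters that are not on the output side are summed, whether or not they are repeated. A letter shared by several inputs is multiplied along that axis before summing (contracted).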
  • Letters appearing on both sides are kept (broadcast or batch).
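The rules above can be tried on small arrays. The following is a minimal sketch, not part of the problem; the array names x and v and the patterns chosen are illustrative assumptions.

```python
import numpy as np

x = np.arange(6).reshape(2, 3)   # axes labelled (i, j): [[0, 1, 2], [3, 4, 5]]
v = np.array([1.0, 2.0, 3.0])    # a single axis labelled (i,)

# 'ij->ji': both letters are kept, only their order changes (a transpose).
transposed = np.einsum('ij->ji', x)

# 'ij->i': j is missing from the output, so it is summed (row sums).
row_sums = np.einsum('ij->i', x)

# 'i,i->': i is shared by both inputs and absent from the output,
# so the elementwise products are summed (a dot product).
dot = np.einsum('i,i->', v, v)

print(transposed.shape)  # (3, 2)
print(row_sums)          # [ 3 12]
print(dot)               # 14.0
```

Note that j in 'ij->i' appears only once on the input side and is still summed: what matters is that it is absent from the output.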

So 'bmk,bkn->bmn' says: the array a has axes (b, m, k) and the array b has axes (b, k, n); the shared label k is summed, the label b is kept as the batch axis, and m and n are kept. That is exactly batched matmul.
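One way to convince yourself of this reading is to compare the einsum result with one ordinary matrix product per batch element. This is a sketch for checking only; the sizes and the random inputs are arbitrary assumptions, and the @ operator appears here purely as a reference, since the problem itself forbids it in the solution.

```python
import numpy as np

rng = np.random.default_rng(0)
B, M, K, N = 2, 3, 4, 5             # arbitrary small sizes for illustration
a = rng.standard_normal((B, M, K))
b = rng.standard_normal((B, K, N))

out = np.einsum('bmk,bkn->bmn', a, b)

# Reference: one ordinary (M, K) @ (K, N) product per batch element.
reference = np.stack([a[i] @ b[i] for i in range(B)])

print(out.shape)                    # (2, 3, 5)
print(np.allclose(out, reference))  # True
```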

Math

$$\mathrm{out}_1[b,m,n] = \sum_k a_{b,m,k}\, b_{b,k,n}, \qquad \mathrm{out}_2[b,m] = \sum_k a_{b,m,k}, \qquad \mathrm{out}_3[b,n] = \sum_k b_{b,k,n}$$
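Each formula maps onto one einsum call by reading its indices off as labels. The following is one possible sketch of the function using the three patterns given in the problem statement; it is not the site's reference solution, and the test inputs in the usage lines are made up for illustration.

```python
import numpy as np

def three_einsum_translations(a, b):
    """a: (B, M, K), b: (B, K, N) -> (out1, out2, out3)."""
    out1 = np.einsum('bmk,bkn->bmn', a, b)  # sum over the shared k: batched matmul
    out2 = np.einsum('bmk->bm', a)          # sum a over k, its last axis
    out3 = np.einsum('bkn->bn', b)          # sum b over k, its middle axis
    return out1, out2, out3

# Illustrative inputs with B=2, M=3, K=4, N=5.
a = np.arange(24, dtype=float).reshape(2, 3, 4)
b = np.arange(40, dtype=float).reshape(2, 4, 5)
out1, out2, out3 = three_einsum_translations(a, b)
print(out1.shape, out2.shape, out3.shape)  # (2, 3, 5) (2, 3) (2, 5)
```

In every pattern the summed index k is the one that is absent from the output side, matching the sums over k in the formulas.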

NumPy

import numpy as np

def three_einsum_translations(a, b):
    # a: (B, M, K), b: (B, K, N); return (out1, out2, out3), each computed with np.einsum
    pass
