Implement a token embedding lookup as a torch.autograd.Function. Forward indexes W by integer ids. Backward scatters incoming gradients back to those rows, summing where the same id appears multiple times.
The rule: you may NOT call F.embedding or nn.Embedding. Use indexing for forward and index_add_ (or equivalent scatter) for backward.
Forward: y = W[idx] where W: (V, D) and idx is integer-typed of arbitrary shape S. Output shape: S + (D,).
Backward: grad_W[v] is the sum of grad_output[p] over every position p in idx where idx[p] == v. Rows whose id never appears in idx stay zero. grad_W has the same shape as W, (V, D).
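A small worked example may make the accumulation rule concrete. The sketch below uses NumPy rather than torch, with toy sizes and an all-ones upstream gradient chosen purely for illustration; np.add.at stands in for index_add_ as the scatter-add.

```python
import numpy as np

# Toy sizes chosen for illustration only: V = 4 rows, D = 2 columns.
V, D = 4, 2
W = np.arange(V * D, dtype=np.float64).reshape(V, D)

# idx has shape S = (2, 3): id 0 appears twice, id 1 three times,
# id 2 once, and id 3 never.
idx = np.array([[1, 0, 1],
                [2, 1, 0]])

# Forward: integer-array indexing yields shape S + (D,) = (2, 3, 2).
y = W[idx]

# Illustrative upstream gradient: all ones, same shape as y.
grad_output = np.ones_like(y)

# Backward: unbuffered scatter-add, so repeated ids accumulate
# instead of overwriting one another.
grad_W = np.zeros_like(W)
np.add.at(grad_W, idx.reshape(-1), grad_output.reshape(-1, D))

print(y.shape)
print(grad_W)
```

With an all-ones upstream gradient, each row of grad_W ends up holding the number of times its id occurs in idx, and the unused row 3 stays zero.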
The driver emb_run(mode, W, idx) dispatches 'forward' | 'grad_W' | 'gradcheck'.
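One possible shape for the solution is sketched below; it is not a reference implementation. The class name EmbeddingLookup is hypothetical, and the statement does not say which upstream gradient 'grad_W' should use or what 'gradcheck' should return, so the sketch assumes a sum-of-outputs loss and returns the boolean from torch.autograd.gradcheck.

```python
import torch


class EmbeddingLookup(torch.autograd.Function):
    """Hypothetical name; the problem statement only fixes emb_run."""

    @staticmethod
    def forward(ctx, W, idx):
        # index_add_ expects int32/int64 indices, so normalise to int64 here.
        idx = idx.long()
        ctx.save_for_backward(idx)
        ctx.num_rows = W.shape[0]
        return W[idx]  # shape S + (D,)

    @staticmethod
    def backward(ctx, grad_output):
        (idx,) = ctx.saved_tensors
        D = grad_output.shape[-1]
        grad_W = torch.zeros(
            ctx.num_rows, D, dtype=grad_output.dtype, device=grad_output.device
        )
        # Flatten S into one axis, then sum rows that share the same id.
        grad_W.index_add_(0, idx.reshape(-1), grad_output.reshape(-1, D))
        # idx is integer-typed and receives no gradient.
        return grad_W, None


def emb_run(mode, W, idx):
    if mode == "forward":
        return EmbeddingLookup.apply(W, idx)
    if mode == "grad_W":
        # Assumption: the loss is the sum of all outputs.
        W = W.detach().clone().requires_grad_(True)
        EmbeddingLookup.apply(W, idx).sum().backward()
        return W.grad
    if mode == "gradcheck":
        # gradcheck compares against finite differences and needs float64.
        W64 = W.detach().double().requires_grad_(True)
        return torch.autograd.gradcheck(
            lambda w: EmbeddingLookup.apply(w, idx), (W64,)
        )
    raise ValueError(f"unknown mode: {mode!r}")
```

Saving only idx and the row count keeps the backward pass independent of W's values, which is enough because the lookup is linear in W.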
import numpy as np
import torch

def emb_run(mode, W, idx):
    # mode is one of 'forward', 'grad_W', 'gradcheck'
    pass