Implement the SwiGLU gating block in PyTorch using primitive tensor ops only.
Signature: def swiglu(x: torch.Tensor, W1: torch.Tensor, W3: torch.Tensor) -> torch.Tensor
The rule: you may NOT call F.silu, nn.SiLU, F.linear, or nn.Linear. Implement the matmuls and the SiLU gate yourself.
Allowed primitives: @ (matmul), .exp(), basic arithmetic.
The block projects the input through W1 and W3 separately, applies SiLU to the W1 branch, and elementwise-multiplies it by the W3 branch. See the math reference below. Implement SiLU yourself in a numerically reasonable way (the standard z · σ(z) form is fine).
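The z · σ(z) form mentioned above can be sketched with the allowed primitives alone. This is a minimal illustration, not a reference solution; the helper name `silu` is an assumption introduced here.

```python
import torch

def silu(z: torch.Tensor) -> torch.Tensor:
    # SiLU(z) = z * sigmoid(z) = z / (1 + exp(-z)), using only .exp() and arithmetic.
    # For large finite negative z, exp(-z) overflows to inf and z / inf gives -0.0,
    # which agrees with the true limit, so no NaN appears.
    return z / (1.0 + (-z).exp())
```

Writing the gate as a single division avoids forming sigmoid(z) as a separate intermediate tensor, though z * (1 / (1 + exp(-z))) computes the same thing.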
Shapes (matching the NumPy version):
x: (d_model,) or (B, d_model) or (B, T, d_model)
W1, W3: (d_ff, d_model)
Output: last dimension d_ff, i.e. (d_ff,) for a 1-D x, with any leading batch dims of x preserved
PyTorch idioms vs the NumPy version:
x @ W.T works identically in NumPy and PyTorch.
.T is the attribute-form transpose; PyTorch deprecates it on tensors that are not 2-D, so for higher-rank tensors prefer .transpose(-1, -2) or .mT.
The @ operator is torch.matmul, which broadcasts the leading batch dims, so a single implementation handles (d,), (B, d), and (B, T, d) inputs.
Math
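A reconstruction of the formulas from the prose description above (project through W1 and W3, apply SiLU to the W1 branch, multiply elementwise). The symbols σ for the logistic sigmoid and ⊙ for the elementwise product are notation assumed here.

```latex
\[
\operatorname{SiLU}(z) = z \cdot \sigma(z) = \frac{z}{1 + e^{-z}},
\qquad
\operatorname{SwiGLU}(x) = \operatorname{SiLU}\!\left(x W_1^{\top}\right) \odot \left(x W_3^{\top}\right)
\]
```

With W1 and W3 of shape (d_ff, d_model), both projections have last dimension d_ff, so the elementwise product is well defined.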
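import torch

def swiglu(x: torch.Tensor, W1: torch.Tensor, W3: torch.Tensor) -> torch.Tensor:
    # Starter stub: implement with @, .exp(), and basic arithmetic only.
    pass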
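One possible way to fill in the stub, sketched under the constraints stated in the problem. The tensor sizes in the demo loop are illustrative assumptions, not values from the problem statement.

```python
import torch

def swiglu(x: torch.Tensor, W1: torch.Tensor, W3: torch.Tensor) -> torch.Tensor:
    # W1 and W3 are (d_ff, d_model); transposing gives (d_model, d_ff),
    # and @ broadcasts over any leading batch dims of x.
    gate = x @ W1.T
    value = x @ W3.T
    # SiLU(gate) = gate * sigmoid(gate), written with .exp() and arithmetic only.
    return gate / (1.0 + (-gate).exp()) * value

# Illustrative sizes (assumed for the demo).
torch.manual_seed(0)
B, T, d_model, d_ff = 2, 3, 4, 8
W1 = torch.randn(d_ff, d_model)
W3 = torch.randn(d_ff, d_model)
for shape in [(d_model,), (B, d_model), (B, T, d_model)]:
    out = swiglu(torch.randn(*shape), W1, W3)
    print(tuple(out.shape))  # leading dims kept, last dim becomes d_ff
```

Because the weights are 2-D, .T is safe here; the same function body covers all three input ranks without reshaping.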