TorchedUp

239. SwiGLU (PyTorch)

Medium

Implement the SwiGLU gating block in PyTorch using primitive tensor ops only.

Signature: def swiglu(x: torch.Tensor, W1: torch.Tensor, W3: torch.Tensor) -> torch.Tensor

The rule: you may NOT call F.silu, nn.SiLU, F.linear, or nn.Linear. Implement the matmuls and the SiLU gate yourself.

Allowed primitives: @ (matmul), .exp(), basic arithmetic.

The block projects the input through W1 and W3 separately, applies SiLU to the W1 branch, and multiplies it elementwise by the W3 branch (see the math reference below). Implement SiLU yourself in a numerically reasonable way; the standard form z · σ(z), with σ(z) = 1 / (1 + exp(−z)), is fine.
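A minimal sketch of the SiLU gate built only from the allowed primitives (.exp() and arithmetic), assuming floating-point inputs. The helper name `silu` is hypothetical; the problem only requires `swiglu`.

```python
import torch


def silu(z: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: SiLU(z) = z * sigma(z) from .exp() and arithmetic."""
    # sigma(z) = 1 / (1 + exp(-z)). For very negative z, exp(-z) can overflow
    # to inf; 1 / (1 + inf) evaluates to 0, so sigma stays finite, not NaN.
    sigma = 1.0 / (1.0 + (-z).exp())
    return z * sigma


z = torch.tensor([-100.0, -1.0, 0.0, 2.0, 100.0])
print(silu(z))
```

This is why the plain z · σ(z) form counts as "numerically reasonable": for finite inputs the overflow case collapses to a gate of 0 rather than producing NaN.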

Shapes (matching the NumPy version):

  • x: (d_model,) or (B, d_model) or (B, T, d_model)
  • W1, W3: (d_ff, d_model)
  • Returns: the same leading (batch) dims as x with a last dim of d_ff, i.e. (d_ff,), (B, d_ff), or (B, T, d_ff)

PyTorch idioms vs the NumPy version:

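  • x @ W.T works identically in NumPy and PyTorch. .T is an attribute (not a method call) that reverses all dimensions; PyTorch deprecates it on tensors that are not 2-D, so for higher-rank tensors prefer .transpose(-1, -2) or .mT, which swap only the last two dims.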
  • In PyTorch, @ calls torch.matmul, which treats a 1-D x as a vector and broadcasts any leading batch dims of x against the 2-D W.T. A single implementation therefore handles (d_model,), (B, d_model), and (B, T, d_model) inputs.
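Putting the pieces together, a minimal sketch of one possible solution, assuming W1 and W3 are always 2-D (d_ff, d_model) as listed in the shapes above. It is one way to satisfy the rules, not the site's reference solution.

```python
import torch


def swiglu(x: torch.Tensor, W1: torch.Tensor, W3: torch.Tensor) -> torch.Tensor:
    # W1, W3 are 2-D (d_ff, d_model), so .T is an ordinary matrix transpose.
    # torch.matmul broadcasts: (..., d_model) @ (d_model, d_ff) -> (..., d_ff).
    h1 = x @ W1.T  # gate branch
    h3 = x @ W3.T  # value branch
    # SiLU(h1) = h1 * sigma(h1), built from .exp() and arithmetic only.
    gate = h1 * (1.0 / (1.0 + (-h1).exp()))
    return gate * h3  # elementwise product


torch.manual_seed(0)
d_model, d_ff = 4, 8
W1 = torch.randn(d_ff, d_model)
W3 = torch.randn(d_ff, d_model)
for shape in [(d_model,), (2, d_model), (2, 3, d_model)]:
    out = swiglu(torch.randn(*shape), W1, W3)
    print(shape, "->", tuple(out.shape))
# prints:
# (4,) -> (8,)
# (2, 4) -> (2, 8)
# (2, 3, 4) -> (2, 3, 8)
```

No reshaping or rank checks are needed: the same two matmuls cover all three input ranks because only the last dim of x participates in the product.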

Math

$$\operatorname{SwiGLU}(x, W_1, W_3) = \operatorname{SiLU}(W_1 x) \odot (W_3 x), \qquad \operatorname{SiLU}(z) = z \cdot \sigma(z)$$
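The formula treats x as a column vector, so W_1 x has shape (d_ff,). When inputs are stored as rows of shape (..., d_model), the same product is x @ W1.T, which is what torch.nn.functional.linear(x, W1) computes (its weight is shaped (out_features, in_features)). That makes F.silu and F.linear a convenient test oracle, even though the rules forbid them in the submission. A sketch of such a check, where swiglu_primitive and swiglu_reference are hypothetical names:

```python
import torch
import torch.nn.functional as F


def swiglu_reference(x, W1, W3):
    # Test oracle only: the exercise forbids F.silu / F.linear in the solution.
    # F.linear(x, W) computes x @ W.T, i.e. W x applied to each row of x.
    return F.silu(F.linear(x, W1)) * F.linear(x, W3)


def swiglu_primitive(x, W1, W3):
    h1, h3 = x @ W1.T, x @ W3.T
    # z * sigma(z) written in the equivalent form z / (1 + exp(-z)).
    return h1 / (1.0 + (-h1).exp()) * h3


torch.manual_seed(0)
x = torch.randn(2, 3, 16, dtype=torch.float64)
W1 = torch.randn(32, 16, dtype=torch.float64)
W3 = torch.randn(32, 16, dtype=torch.float64)
print(torch.allclose(swiglu_primitive(x, W1, W3), swiglu_reference(x, W1, W3)))
```

Running the comparison in float64 keeps rounding differences between the two code paths far below the default tolerances of torch.allclose.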

Related problems

  • SwiGLU (Gated Linear Unit): Medium, NumPy


