Implement ELU and SELU in PyTorch using primitive tensor ops only.
Signature: def elu_selu(x: torch.Tensor, mode: str = 'elu', alpha: float = 1.0) -> torch.Tensor
The rule: you may NOT call F.elu, F.selu, nn.ELU, or nn.SELU. We verify your output matches F.elu(x, alpha) and F.selu(x) respectively.
Allowed primitives: .exp(), torch.where, basic arithmetic.
Formulas:
  ELU(x)  = x                             if x > 0
          = alpha * (exp(x) - 1)          if x <= 0
  SELU(x) = scale * x                     if x > 0
          = scale * alpha * (exp(x) - 1)  if x <= 0
  scale = 1.0507009873554804
  alpha = 1.6732631921096593 (SELU's fixed alpha; ignore the function arg in 'selu' mode)
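Read side by side, the two definitions differ only by constants: SELU is ELU evaluated with alpha fixed to SELU's alpha, then multiplied by scale. That suggests one code path can serve both modes. Writing lambda for scale and alpha_s for SELU's fixed alpha (notation introduced here, not in the problem):

```latex
\mathrm{ELU}_{\alpha}(x) =
\begin{cases}
  x & x > 0 \\
  \alpha\,(e^{x} - 1) & x \le 0
\end{cases}
\qquad
\mathrm{SELU}(x) = \lambda\,\mathrm{ELU}_{\alpha_s}(x) =
\begin{cases}
  \lambda x & x > 0 \\
  \lambda\,\alpha_s\,(e^{x} - 1) & x \le 0
\end{cases}
\qquad
\lambda = 1.0507009873554804,\quad \alpha_s = 1.6732631921096593
```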
PyTorch idioms vs the NumPy version:
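torch.where(cond, a, b) selects element-wise, taking a where cond is True and b elsewhere. The condition must be a boolean tensor; x > 0 produces one automatically. Both a and b are computed in full before torch.where selects between them, so x.exp() is also evaluated on positive inputs whose results are then discarded. This is fine for the forward pass, but be aware that gradients flow through both branches: the discarded branch receives a zero gradient, and if exp overflowed to inf there (in float32 this happens a little above x = 88.7), the backward pass computes 0 * inf = NaN.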
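A small sketch of the gradient issue, assuming float32 inputs. The function names elu_naive and elu_masked are hypothetical. The naive version evaluates exp on every element; the masked version is one possible workaround, using only torch.where and arithmetic, that swaps positive inputs for -x before exp so its argument is never positive.

```python
import torch


def elu_naive(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # exp is evaluated on every element, including large positive ones.
    return torch.where(x > 0, x, alpha * (x.exp() - 1))


def elu_masked(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Swap positive inputs for -x before exp so its argument is never positive;
    # the outer torch.where discards those positions anyway.
    x_nonpos = torch.where(x > 0, -x, x)
    return torch.where(x > 0, x, alpha * (x_nonpos.exp() - 1))


# float32 exp(100.0) overflows to inf in the discarded branch of elu_naive.
x1 = torch.tensor([-1.0, 0.5, 100.0], requires_grad=True)
elu_naive(x1).sum().backward()

x2 = torch.tensor([-1.0, 0.5, 100.0], requires_grad=True)
elu_masked(x2).sum().backward()

print(x1.grad)  # last entry is nan: zero upstream gradient times inf from exp
print(x2.grad)  # all entries finite
```

Both functions return the same forward values; only the backward pass differs.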
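import torch

def elu_selu(x: torch.Tensor, mode: str = 'elu', alpha: float = 1.0) -> torch.Tensor:
    # Your implementation: use only .exp(), torch.where and basic arithmetic.
    pass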
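One way to fill in the starter stub, sketched under the problem's constraints (torch.where, .exp() and arithmetic only). It uses the identity that SELU is scale times ELU with SELU's fixed alpha, and feeds exp only non-positive values so large positive inputs cannot overflow and produce NaN gradients. The -x substitution and the ValueError for an unknown mode are choices of this sketch, not part of the problem statement. The check at the bottom calls F.elu and F.selu only as references, mirroring the comparison the problem describes; the tolerance the actual grader uses is not stated, so atol=1e-6 here is an assumption.

```python
import torch

# SELU constants as given in the problem statement.
SELU_SCALE = 1.0507009873554804
SELU_ALPHA = 1.6732631921096593


def elu_selu(x: torch.Tensor, mode: str = 'elu', alpha: float = 1.0) -> torch.Tensor:
    """ELU or SELU using only torch.where, .exp() and arithmetic."""
    if mode == 'elu':
        scale, a = 1.0, alpha
    elif mode == 'selu':
        # SELU ignores the alpha argument and uses its fixed constants.
        scale, a = SELU_SCALE, SELU_ALPHA
    else:
        # Assumption of this sketch: reject unknown modes explicitly.
        raise ValueError(f"mode must be 'elu' or 'selu', got {mode!r}")

    positive = x > 0  # boolean mask for torch.where
    # Feed exp only non-positive values: where x > 0, use -x (which is < 0).
    # Those positions take the x branch below, so the substitute is discarded,
    # but exp can no longer overflow there and turn the gradient into NaN.
    x_nonpos = torch.where(positive, -x, x)
    negative_branch = a * (x_nonpos.exp() - 1)
    return scale * torch.where(positive, x, negative_branch)


if __name__ == "__main__":
    import torch.nn.functional as F

    x = torch.linspace(-5.0, 5.0, steps=101)
    # atol=1e-6 allows for rounding differences between exp(x) - 1 here and
    # the reference kernel near zero.
    print(torch.allclose(elu_selu(x, 'elu', alpha=0.5), F.elu(x, alpha=0.5), atol=1e-6))
    print(torch.allclose(elu_selu(x, 'selu'), F.selu(x), atol=1e-6))
```

Computing the shared ELU shape once and multiplying by scale keeps the two modes from drifting apart; in 'elu' mode the multiplication by 1.0 leaves the values unchanged.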