Implement Mish in PyTorch using primitive tensor ops only.
Signature: def mish(x: torch.Tensor) -> torch.Tensor
The rule: you may NOT call F.mish, nn.Mish, or F.softplus. Implement softplus yourself.
Allowed primitives: .exp(), .log1p(), and .tanh(). tanh is fundamental enough to use directly, so you do not need to build it by hand from exponentials, e.g. via (exp(z) - exp(-z)) / (exp(z) + exp(-z)) or the more stable identity tanh(z) = 1 - 2/(1 + exp(2z)).
Formula:
mish(x) = x * tanh(softplus(x))
softplus(x) = log(1 + exp(x)) — but use log1p for stability
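The formula above can be sketched directly with tensor methods. This is a minimal illustration rather than a reference solution; the helper name softplus_naive is hypothetical and not part of the problem statement.

```python
import torch


def softplus_naive(x: torch.Tensor) -> torch.Tensor:
    # log(1 + exp(x)) written with log1p; fine for moderate x,
    # but exp(x) overflows for large positive inputs.
    return x.exp().log1p()


def mish(x: torch.Tensor) -> torch.Tensor:
    # mish(x) = x * tanh(softplus(x)); .tanh() is permitted by the rules.
    return x * softplus_naive(x).tanh()


x = torch.tensor([-2.0, 0.0, 3.0])
y = mish(x)
```

Because every step is an elementwise tensor method, the function works for any input shape and stays differentiable under autograd.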
PyTorch idioms vs the NumPy version:
NumPy writes softplus as np.log1p(np.exp(x)). PyTorch has x.exp().log1p(): same trick, method-chained.
For large x, exp(x) overflows. The classic stable form of softplus is max(x, 0) + log1p(exp(-|x|)). For the test inputs here the naive form works; for production-grade code, use the stable form.
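As a sketch of the stable form max(x, 0) + log1p(exp(-|x|)), the snippet below compares it against the naive chain on extreme float32 inputs. The names softplus_stable and mish_stable are hypothetical, chosen only for this illustration.

```python
import torch


def softplus_stable(x: torch.Tensor) -> torch.Tensor:
    # max(x, 0) + log1p(exp(-|x|)); exp only ever sees non-positive
    # arguments, so it cannot overflow.
    return x.clamp(min=0) + x.abs().neg().exp().log1p()


def mish_stable(x: torch.Tensor) -> torch.Tensor:
    # Same Mish formula, with the stable softplus swapped in.
    return x * softplus_stable(x).tanh()


x = torch.tensor([-100.0, 0.0, 100.0])
naive = x.exp().log1p()       # exp(100) overflows float32, so the last entry is inf
stable = softplus_stable(x)   # every entry stays finite
```

The two forms agree wherever the naive one does not overflow, so swapping in the stable version should not change results on moderate inputs.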
import torch
def mish(x: torch.Tensor) -> torch.Tensor:
    pass  # your implementation goes here