TorchedUp

237. Mish (PyTorch)

Easy

Implement Mish in PyTorch using primitive tensor ops only.

Signature: def mish(x: torch.Tensor) -> torch.Tensor

The rule: you may NOT call F.mish, nn.Mish, or F.softplus. Implement softplus yourself.

Allowed primitives: .exp(), .log1p(), and .tanh(). Calling .tanh() is permitted because it is fundamental enough and composing it by hand gets unwieldy. If you want to build it yourself anyway, use (z.exp() - z.neg().exp()) / (z.exp() + z.neg().exp()), or the identity tanh(z) = 1 - 2/(1 + exp(2z)), which stays finite even when exp(2z) overflows.
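
Using .tanh() is allowed, but the identity tanh(z) = 1 - 2/(1 + exp(2z)) is easy to check by hand. A minimal sketch, assuming float32 inputs; the helper name tanh_manual is made up for illustration and is not part of the problem:

```python
import torch

def tanh_manual(z: torch.Tensor) -> torch.Tensor:
    # tanh(z) = 1 - 2 / (1 + exp(2z)).
    # If exp(2z) overflows to inf, 2 / inf is 0 and the result is 1.
    # If exp(2z) underflows to 0, the result is 1 - 2 = -1.
    return 1 - 2 / (1 + (2 * z).exp())

z = torch.tensor([-50.0, -1.0, 0.0, 1.0, 50.0])
print(tanh_manual(z))   # hand-rolled version
print(torch.tanh(z))    # built-in, shown only for comparison
```

The two printed tensors should agree to within float32 rounding, including at the extreme inputs.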

Formula:

mish(x) = x * tanh(softplus(x))
softplus(x) = log(1 + exp(x))    (compute it as log1p(exp(x)), which keeps precision when exp(x) is tiny)
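
The formula translates almost word for word into tensor methods. The following is one possible sketch of the naive form, not the site's reference solution; the helper name softplus_naive is an assumption made for readability:

```python
import torch

def softplus_naive(x: torch.Tensor) -> torch.Tensor:
    # log(1 + exp(x)), written with log1p for small exp(x)
    return x.exp().log1p()

def mish(x: torch.Tensor) -> torch.Tensor:
    # mish(x) = x * tanh(softplus(x))
    return x * softplus_naive(x).tanh()

x = torch.tensor([-2.0, 0.0, 1.0, 3.0])
print(mish(x))
```

The output has the same shape and dtype as the input, since every step is elementwise.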

PyTorch idioms vs the NumPy version:

  • NumPy uses np.log1p(np.exp(x)). PyTorch has x.exp().log1p() — same trick, method-chained.
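  • For very large x, exp(x) overflows to inf (in float32 this happens above roughly x = 88.7). The classic stable form of softplus is max(x, 0) + log1p(exp(-|x|)), which only ever exponentiates a non-positive number. For the test inputs here the naive form works; for production-grade code, use the stable form.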
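
The stable form mentioned above can be sketched as follows. It is an illustration under the assumption of float32 tensors; softplus_stable and mish_stable are hypothetical names, and x.clamp(min=0) stands in for max(x, 0):

```python
import torch

def softplus_stable(x: torch.Tensor) -> torch.Tensor:
    # max(x, 0) + log1p(exp(-|x|)); the exponent is never positive,
    # so exp() cannot overflow.
    return x.clamp(min=0) + (-x.abs()).exp().log1p()

def mish_stable(x: torch.Tensor) -> torch.Tensor:
    return x * softplus_stable(x).tanh()

x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(x.exp().log1p())      # naive softplus: the entry for x = 100 is inf in float32
print(softplus_stable(x))   # stable softplus: every entry is finite
print(mish_stable(x))
```

For x > 0 the identity follows from log(1 + exp(x)) = x + log(1 + exp(-x)); for x <= 0 the two forms are literally the same expression.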

Math

$$\mathrm{Mish}(x) = x \cdot \tanh\left(\ln\left(1 + e^{x}\right)\right)$$

Related problems

  • Mish Activation (Easy, NumPy)

PyTorch

import torch

def mish(x: torch.Tensor) -> torch.Tensor:
    # TODO: implement with .exp(), .log1p(), and .tanh()
    pass
