Update functions/activations.py

Jensun Ravichandran 2021-05-04 20:55:49 +02:00
parent 466e9bde6b
commit b06ded683d


@@ -16,40 +16,43 @@ def register_activation(function):
 @register_activation
 # @torch.jit.script
-def identity(x, beta=torch.tensor(0)):
+def identity(x, beta=0.0):
     """Identity activation function.
 
     Definition:
         :math:`f(x) = x`
+
+    Keyword Arguments:
+        beta (`float`): Ignored.
     """
     return x
 
 
 @register_activation
 # @torch.jit.script
-def sigmoid_beta(x, beta=torch.tensor(10)):
+def sigmoid_beta(x, beta=10.0):
     r"""Sigmoid activation function with scaling.
 
     Definition:
         :math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
 
     Keyword Arguments:
-        beta (`torch.tensor`): Scaling parameter :math:`\beta`
+        beta (`float`): Scaling parameter :math:`\beta`
     """
-    out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x))
+    out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x))
     return out
 
 
 @register_activation
 # @torch.jit.script
-def swish_beta(x, beta=torch.tensor(10)):
+def swish_beta(x, beta=10.0):
     r"""Swish activation function with scaling.
 
     Definition:
         :math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
 
     Keyword Arguments:
-        beta (`torch.tensor`): Scaling parameter :math:`\beta`
+        beta (`float`): Scaling parameter :math:`\beta`
     """
     out = x * sigmoid_beta(x, beta=beta)
     return out
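
For context, a minimal sanity-check sketch of the new float-valued beta API. The function bodies below mirror the added (+) lines in the diff; the sample input is illustrative and not part of the commit. The old code cast the scaling parameter with int(beta.item()), which truncated fractional values (e.g. 0.5 became 0); the new form keeps beta as a plain float:

import torch


def sigmoid_beta(x, beta=10.0):
    # Scaled sigmoid: f(x) = 1 / (1 + exp(-beta * x))
    return 1.0 / (1.0 + torch.exp(-1.0 * beta * x))


def swish_beta(x, beta=10.0):
    # Scaled swish: f(x) = x * sigmoid_beta(x, beta)
    return x * sigmoid_beta(x, beta=beta)


x = torch.linspace(-1.0, 1.0, steps=5)

# beta is a plain float now, so fractional scaling works as expected
# (previously int(beta.item()) would have truncated 0.5 to 0)
print(sigmoid_beta(x, beta=0.5))
print(swish_beta(x, beta=0.5))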