Update functions/activations.py

This commit is contained in:
Jensun Ravichandran 2021-05-04 20:55:49 +02:00
parent 466e9bde6b
commit b06ded683d

View File

@ -16,40 +16,43 @@ def register_activation(function):
@register_activation @register_activation
# @torch.jit.script # @torch.jit.script
def identity(x, beta=torch.tensor(0)): def identity(x, beta=0.0):
"""Identity activation function. """Identity activation function.
Definition: Definition:
:math:`f(x) = x` :math:`f(x) = x`
Keyword Arguments:
beta (`float`): Ignored.
""" """
return x return x
@register_activation @register_activation
# @torch.jit.script # @torch.jit.script
def sigmoid_beta(x, beta=torch.tensor(10)): def sigmoid_beta(x, beta=10.0):
r"""Sigmoid activation function with scaling. r"""Sigmoid activation function with scaling.
Definition: Definition:
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}` :math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
Keyword Arguments: Keyword Arguments:
beta (`torch.tensor`): Scaling parameter :math:`\beta` beta (`float`): Scaling parameter :math:`\beta`
""" """
out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x)) out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x))
return out return out
@register_activation @register_activation
# @torch.jit.script # @torch.jit.script
def swish_beta(x, beta=torch.tensor(10)): def swish_beta(x, beta=10.0):
r"""Swish activation function with scaling. r"""Swish activation function with scaling.
Definition: Definition:
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}` :math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
Keyword Arguments: Keyword Arguments:
beta (`torch.tensor`): Scaling parameter :math:`\beta` beta (`float`): Scaling parameter :math:`\beta`
""" """
out = x * sigmoid_beta(x, beta=beta) out = x * sigmoid_beta(x, beta=beta)
return out return out