diff --git a/prototorch/functions/activations.py b/prototorch/functions/activations.py index 8ebf3e6..9f6554a 100644 --- a/prototorch/functions/activations.py +++ b/prototorch/functions/activations.py @@ -5,30 +5,36 @@ import torch ACTIVATIONS = dict() -def register_activation(func): - ACTIVATIONS[func.__name__] = func - return func +# def register_activation(scriptf): +# ACTIVATIONS[scriptf.name] = scriptf +# return scriptf +def register_activation(f): + ACTIVATIONS[f.__name__] = f + return f @register_activation -def identity(input, **kwargs): +# @torch.jit.script +def identity(input, beta=torch.tensor([0])): """:math:`f(x) = x`""" return input @register_activation -def sigmoid_beta(input, beta=10): +# @torch.jit.script +def sigmoid_beta(input, beta=torch.tensor([10])): """:math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}` Keyword Arguments: beta (float): Parameter :math:`\\beta` """ - out = torch.reciprocal(1.0 + torch.exp(-beta * input)) + out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input)) return out @register_activation -def swish_beta(input, beta=10): +# @torch.jit.script +def swish_beta(input, beta=torch.tensor([10])): """:math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}` Keyword Arguments: diff --git a/prototorch/functions/competitions.py b/prototorch/functions/competitions.py index f5709a2..48cf78c 100644 --- a/prototorch/functions/competitions.py +++ b/prototorch/functions/competitions.py @@ -3,13 +3,15 @@ import torch +# @torch.jit.script def wtac(distances, labels): winning_indices = torch.min(distances, dim=1).indices winning_labels = labels[winning_indices].squeeze() return winning_labels +# @torch.jit.script def knnc(distances, labels, k): - winning_indices = torch.topk(-distances, k=k, dim=1).indices + winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices winning_labels = labels[winning_indices].squeeze() return winning_labels