Prepare activation and competition functions for TorchScript

This commit is contained in:
blackfly 2020-04-08 22:42:56 +02:00
parent 900955d67a
commit bde408a80e
2 changed files with 16 additions and 8 deletions

View File

@ -5,30 +5,36 @@ import torch
ACTIVATIONS = dict() ACTIVATIONS = dict()
def register_activation(func): # def register_activation(scriptf):
ACTIVATIONS[func.__name__] = func # ACTIVATIONS[scriptf.name] = scriptf
return func # return scriptf
def register_activation(f):
ACTIVATIONS[f.__name__] = f
return f
@register_activation @register_activation
def identity(input, **kwargs): # @torch.jit.script
def identity(input, beta=torch.tensor([0])):
""":math:`f(x) = x`""" """:math:`f(x) = x`"""
return input return input
@register_activation @register_activation
def sigmoid_beta(input, beta=10): # @torch.jit.script
def sigmoid_beta(input, beta=torch.tensor([10])):
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}` """:math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
Keyword Arguments: Keyword Arguments:
beta (float): Parameter :math:`\\beta` beta (float): Parameter :math:`\\beta`
""" """
out = torch.reciprocal(1.0 + torch.exp(-beta * input)) out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input))
return out return out
@register_activation @register_activation
def swish_beta(input, beta=10): # @torch.jit.script
def swish_beta(input, beta=torch.tensor([10])):
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}` """:math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
Keyword Arguments: Keyword Arguments:

View File

@ -3,13 +3,15 @@
import torch import torch
# @torch.jit.script
def wtac(distances, labels): def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze() winning_labels = labels[winning_indices].squeeze()
return winning_labels return winning_labels
# @torch.jit.script
def knnc(distances, labels, k): def knnc(distances, labels, k):
winning_indices = torch.topk(-distances, k=k, dim=1).indices winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
winning_labels = labels[winning_indices].squeeze() winning_labels = labels[winning_indices].squeeze()
return winning_labels return winning_labels