From bde408a80e7a6c8e9058640245338d417a298c66 Mon Sep 17 00:00:00 2001 From: blackfly Date: Wed, 8 Apr 2020 22:42:56 +0200 Subject: [PATCH] Prepare activation and competition functions for TorchScript --- prototorch/functions/activations.py | 20 +++++++++++++------- prototorch/functions/competitions.py | 4 +++- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/prototorch/functions/activations.py b/prototorch/functions/activations.py index 8ebf3e6..9f6554a 100644 --- a/prototorch/functions/activations.py +++ b/prototorch/functions/activations.py @@ -5,30 +5,36 @@ import torch ACTIVATIONS = dict() -def register_activation(func): - ACTIVATIONS[func.__name__] = func - return func +# def register_activation(scriptf): +# ACTIVATIONS[scriptf.name] = scriptf +# return scriptf +def register_activation(f): + ACTIVATIONS[f.__name__] = f + return f @register_activation -def identity(input, **kwargs): +# @torch.jit.script +def identity(input, beta=torch.tensor([0])): """:math:`f(x) = x`""" return input @register_activation -def sigmoid_beta(input, beta=10): +# @torch.jit.script +def sigmoid_beta(input, beta=torch.tensor([10])): """:math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}` Keyword Arguments: beta (float): Parameter :math:`\\beta` """ - out = torch.reciprocal(1.0 + torch.exp(-beta * input)) + out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input)) return out @register_activation -def swish_beta(input, beta=10): +# @torch.jit.script +def swish_beta(input, beta=torch.tensor([10])): """:math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}` Keyword Arguments: diff --git a/prototorch/functions/competitions.py b/prototorch/functions/competitions.py index f5709a2..48cf78c 100644 --- a/prototorch/functions/competitions.py +++ b/prototorch/functions/competitions.py @@ -3,13 +3,15 @@ import torch +# @torch.jit.script def wtac(distances, labels): winning_indices = torch.min(distances, dim=1).indices winning_labels = labels[winning_indices].squeeze() return winning_labels +# @torch.jit.script def knnc(distances, labels, k): - winning_indices = torch.topk(-distances, k=k, dim=1).indices + winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices winning_labels = labels[winning_indices].squeeze() return winning_labels