Prepare activation and competition functions for TorchScript
This commit is contained in:
parent
900955d67a
commit
bde408a80e
@ -5,30 +5,36 @@ import torch
|
|||||||
ACTIVATIONS = dict()
|
ACTIVATIONS = dict()
|
||||||
|
|
||||||
|
|
||||||
def register_activation(func):
|
# def register_activation(scriptf):
|
||||||
ACTIVATIONS[func.__name__] = func
|
# ACTIVATIONS[scriptf.name] = scriptf
|
||||||
return func
|
# return scriptf
|
||||||
|
def register_activation(f):
|
||||||
|
ACTIVATIONS[f.__name__] = f
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
def identity(input, **kwargs):
|
# @torch.jit.script
|
||||||
|
def identity(input, beta=torch.tensor([0])):
|
||||||
""":math:`f(x) = x`"""
|
""":math:`f(x) = x`"""
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
def sigmoid_beta(input, beta=10):
|
# @torch.jit.script
|
||||||
|
def sigmoid_beta(input, beta=torch.tensor([10])):
|
||||||
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
|
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
beta (float): Parameter :math:`\\beta`
|
beta (float): Parameter :math:`\\beta`
|
||||||
"""
|
"""
|
||||||
out = torch.reciprocal(1.0 + torch.exp(-beta * input))
|
out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
def swish_beta(input, beta=10):
|
# @torch.jit.script
|
||||||
|
def swish_beta(input, beta=torch.tensor([10])):
|
||||||
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
|
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
|
@ -3,13 +3,15 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
def wtac(distances, labels):
|
def wtac(distances, labels):
|
||||||
winning_indices = torch.min(distances, dim=1).indices
|
winning_indices = torch.min(distances, dim=1).indices
|
||||||
winning_labels = labels[winning_indices].squeeze()
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
return winning_labels
|
return winning_labels
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
def knnc(distances, labels, k):
|
def knnc(distances, labels, k):
|
||||||
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
|
||||||
winning_labels = labels[winning_indices].squeeze()
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
return winning_labels
|
return winning_labels
|
||||||
|
Loading…
Reference in New Issue
Block a user