Prepare activation and competition functions for TorchScript
Activation functions:

@@ -5,30 +5,36 @@ import torch
 ACTIVATIONS = dict()
 
 
-def register_activation(func):
-    ACTIVATIONS[func.__name__] = func
-    return func
+# def register_activation(scriptf):
+#     ACTIVATIONS[scriptf.name] = scriptf
+#     return scriptf
+def register_activation(f):
+    ACTIVATIONS[f.__name__] = f
+    return f
 
 
 @register_activation
-def identity(input, **kwargs):
+# @torch.jit.script
+def identity(input, beta=torch.tensor([0])):
     """:math:`f(x) = x`"""
     return input
 
 
 @register_activation
-def sigmoid_beta(input, beta=10):
+# @torch.jit.script
+def sigmoid_beta(input, beta=torch.tensor([10])):
     """:math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
 
     Keyword Arguments:
         beta (float): Parameter :math:`\\beta`
     """
-    out = torch.reciprocal(1.0 + torch.exp(-beta * input))
+    out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input))
     return out
 
 
 @register_activation
-def swish_beta(input, beta=10):
+# @torch.jit.script
+def swish_beta(input, beta=torch.tensor([10])):
     """:math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
 
     Keyword Arguments:
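For context, register_activation builds a name-to-function registry, and the scalar defaults become 1-element tensors, presumably so the signatures stay tensor-typed once the commented-out @torch.jit.script decorators are enabled. A minimal usage sketch (mine, not part of the commit; looking activations up by string name is an assumption about how callers use the registry):

import torch

ACTIVATIONS = dict()


def register_activation(f):
    # Map the function's name to the function itself.
    ACTIVATIONS[f.__name__] = f
    return f


@register_activation
def sigmoid_beta(input, beta=torch.tensor([10])):
    # beta arrives as a 1-element tensor; .item() unwraps it to a
    # Python scalar before it scales the input.
    return torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input))


x = torch.linspace(-3.0, 3.0, steps=7)
f = ACTIVATIONS["sigmoid_beta"]    # resolve the activation by name
y = f(x, beta=torch.tensor([2]))   # shallower curve than the default beta=10
print(y.shape)                     # torch.Size([7])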
Competition functions:

@@ -3,13 +3,15 @@
 import torch
 
 
+# @torch.jit.script
 def wtac(distances, labels):
     winning_indices = torch.min(distances, dim=1).indices
     winning_labels = labels[winning_indices].squeeze()
     return winning_labels
 
 
+# @torch.jit.script
 def knnc(distances, labels, k):
-    winning_indices = torch.topk(-distances, k=k, dim=1).indices
+    winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
     winning_labels = labels[winning_indices].squeeze()
     return winning_labels
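wtac is a winner-takes-all competition: each sample receives the label of its closest prototype. knnc returns the labels of the k closest prototypes instead. The switch from k=k to k=k.item() lets callers pass k as a tensor (torch.topk itself needs a Python int), which fits the tensor-only signatures TorchScript prefers. A hedged usage sketch (the example data is mine, not from the commit):

import torch


def wtac(distances, labels):
    winning_indices = torch.min(distances, dim=1).indices
    return labels[winning_indices].squeeze()


def knnc(distances, labels, k):
    # torch.topk needs a Python int, so a tensor-valued k is unwrapped.
    winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
    return labels[winning_indices].squeeze()


# distances: (num_samples, num_prototypes); labels: one label per prototype.
distances = torch.tensor([[0.3, 1.2, 0.7],
                          [2.0, 0.1, 0.9]])
labels = torch.tensor([0, 1, 1])

print(wtac(distances, labels))                       # tensor([0, 1])
print(knnc(distances, labels, k=torch.tensor([2])))  # labels of the 2 nearest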