Add basic prototorch functions needed for GLVQ
This commit is contained in:
parent
f9bc4a29c9
commit
33e8f1297f
48
prototorch/functions/activations.py
Normal file
48
prototorch/functions/activations.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
"""ProtoTorch activation functions."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ACTIVATIONS = dict()
|
||||||
|
|
||||||
|
|
||||||
|
def register_activation(func):
|
||||||
|
ACTIVATIONS[func.__name__] = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
@register_activation
|
||||||
|
def identity(input, **kwargs):
|
||||||
|
""":math:`f(x) = x`"""
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
@register_activation
|
||||||
|
def sigmoid_beta(input, beta=10):
|
||||||
|
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
beta (float): Parameter :math:`\\beta`
|
||||||
|
"""
|
||||||
|
out = torch.reciprocal(1.0 + torch.exp(-beta * input))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@register_activation
|
||||||
|
def swish_beta(input, beta=10):
|
||||||
|
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
beta (float): Parameter :math:`\\beta`
|
||||||
|
"""
|
||||||
|
out = input * sigmoid_beta(input, beta=beta)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def get_activation(funcname):
|
||||||
|
if callable(funcname):
|
||||||
|
return funcname
|
||||||
|
else:
|
||||||
|
if funcname in ACTIVATIONS:
|
||||||
|
return ACTIVATIONS.get(funcname)
|
||||||
|
else:
|
||||||
|
raise NameError(f'Activation {funcname} was not found.')
|
15
prototorch/functions/competitions.py
Normal file
15
prototorch/functions/competitions.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
"""ProtoTorch competition functions."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def wtac(distances, labels):
|
||||||
|
winning_indices = torch.min(distances, dim=1).indices
|
||||||
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
|
return winning_labels
|
||||||
|
|
||||||
|
|
||||||
|
def knnc(distances, labels, k):
|
||||||
|
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
||||||
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
|
return winning_labels
|
@ -33,11 +33,14 @@ def lpnorm_distance(x, y, p):
|
|||||||
Expected dimension of x is 2.
|
Expected dimension of x is 2.
|
||||||
Expected dimension of y is 2.
|
Expected dimension of y is 2.
|
||||||
"""
|
"""
|
||||||
expanded_x = x.unsqueeze(dim=1)
|
# # DEPRECATED in favor of torch.cdist
|
||||||
batchwise_difference = y - expanded_x
|
# expanded_x = x.unsqueeze(dim=1)
|
||||||
differences_raised = torch.pow(batchwise_difference, p)
|
# batchwise_difference = y - expanded_x
|
||||||
distances_raised = torch.sum(differences_raised, axis=2)
|
# differences_raised = torch.pow(batchwise_difference, p)
|
||||||
distances = torch.pow(distances_raised, 1.0 / p)
|
# distances_raised = torch.sum(differences_raised, axis=2)
|
||||||
|
# distances = torch.pow(distances_raised, 1.0 / p)
|
||||||
|
# return distances
|
||||||
|
distances = torch.cdist(x, y, p=p)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
|
93
prototorch/functions/initializers.py
Normal file
93
prototorch/functions/initializers.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
"""ProtoTorch initialization functions."""
|
||||||
|
|
||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
INITIALIZERS = dict()
|
||||||
|
|
||||||
|
|
||||||
|
def register_initializer(func):
|
||||||
|
INITIALIZERS[func.__name__] = func
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def labels_from(distribution):
|
||||||
|
"""Takes a distribution tensor and returns a labels tensor."""
|
||||||
|
nclasses = distribution.shape[0]
|
||||||
|
llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
|
||||||
|
# labels = [l for cl in llist for l in cl] # flatten the list of lists
|
||||||
|
labels = list(chain(*llist)) # flatten using itertools.chain
|
||||||
|
return torch.tensor(labels, requires_grad=False)
|
||||||
|
|
||||||
|
|
||||||
|
@register_initializer
|
||||||
|
def ones(x_train, y_train, prototype_distribution):
|
||||||
|
nprotos = torch.sum(prototype_distribution)
|
||||||
|
protos = torch.ones(nprotos, *x_train.shape[1:])
|
||||||
|
plabels = labels_from(prototype_distribution)
|
||||||
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
|
@register_initializer
|
||||||
|
def zeros(x_train, y_train, prototype_distribution):
|
||||||
|
nprotos = torch.sum(prototype_distribution)
|
||||||
|
protos = torch.zeros(nprotos, *x_train.shape[1:])
|
||||||
|
plabels = labels_from(prototype_distribution)
|
||||||
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
|
@register_initializer
|
||||||
|
def rand(x_train, y_train, prototype_distribution):
|
||||||
|
nprotos = torch.sum(prototype_distribution)
|
||||||
|
protos = torch.rand(nprotos, *x_train.shape[1:])
|
||||||
|
plabels = labels_from(prototype_distribution)
|
||||||
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
|
@register_initializer
|
||||||
|
def randn(x_train, y_train, prototype_distribution):
|
||||||
|
nprotos = torch.sum(prototype_distribution)
|
||||||
|
protos = torch.randn(nprotos, *x_train.shape[1:])
|
||||||
|
plabels = labels_from(prototype_distribution)
|
||||||
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
|
@register_initializer
|
||||||
|
def stratified_mean(x_train, y_train, prototype_distribution):
|
||||||
|
nprotos = torch.sum(prototype_distribution)
|
||||||
|
pdim = x_train.shape[1]
|
||||||
|
protos = torch.empty(nprotos, pdim)
|
||||||
|
plabels = labels_from(prototype_distribution)
|
||||||
|
for i, l in enumerate(plabels):
|
||||||
|
xl = x_train[y_train == l]
|
||||||
|
mean_xl = torch.mean(xl, dim=0)
|
||||||
|
protos[i] = mean_xl
|
||||||
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
|
@register_initializer
|
||||||
|
def stratified_random(x_train, y_train, prototype_distribution):
|
||||||
|
gen = torch.manual_seed(torch.initial_seed())
|
||||||
|
nprotos = torch.sum(prototype_distribution)
|
||||||
|
pdim = x_train.shape[1]
|
||||||
|
protos = torch.empty(nprotos, pdim)
|
||||||
|
plabels = labels_from(prototype_distribution)
|
||||||
|
for i, l in enumerate(plabels):
|
||||||
|
xl = x_train[y_train == l]
|
||||||
|
rand_index = torch.zeros(1).long().random_(0,
|
||||||
|
xl.shape[1] - 1,
|
||||||
|
generator=gen)
|
||||||
|
random_xl = xl[rand_index]
|
||||||
|
protos[i] = random_xl
|
||||||
|
return protos, plabels
|
||||||
|
|
||||||
|
|
||||||
|
def get_initializer(funcname):
|
||||||
|
if callable(funcname):
|
||||||
|
return funcname
|
||||||
|
else:
|
||||||
|
if funcname in INITIALIZERS:
|
||||||
|
return INITIALIZERS.get(funcname)
|
||||||
|
else:
|
||||||
|
raise NameError(f'Initializer {funcname} was not found.')
|
25
prototorch/functions/losses.py
Normal file
25
prototorch/functions/losses.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
"""ProtoTorch loss functions."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def glvq_loss(distances, target_labels, prototype_labels):
|
||||||
|
"""GLVQ loss function with support for one-hot labels."""
|
||||||
|
matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
|
||||||
|
if prototype_labels.ndim == 2:
|
||||||
|
# if the labels are one-hot vectors
|
||||||
|
nclasses = target_labels.size()[1]
|
||||||
|
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||||
|
not_matcher = torch.bitwise_not(matcher)
|
||||||
|
|
||||||
|
dplus_criterion = distances * matcher > 0.0
|
||||||
|
dminus_criterion = distances * not_matcher > 0.0
|
||||||
|
|
||||||
|
inf = torch.full_like(distances, fill_value=float('inf'))
|
||||||
|
distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
|
||||||
|
distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
|
||||||
|
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
|
||||||
|
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
|
||||||
|
|
||||||
|
mu = (dpluses - dminuses) / (dpluses + dminuses)
|
||||||
|
return mu
|
Loading…
Reference in New Issue
Block a user