Add basic prototorch functions needed for GLVQ
This commit is contained in:
		
							
								
								
									
										48
									
								
								prototorch/functions/activations.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								prototorch/functions/activations.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,48 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch activation functions."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ACTIVATIONS = dict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def register_activation(func):
 | 
				
			||||||
 | 
					    ACTIVATIONS[func.__name__] = func
 | 
				
			||||||
 | 
					    return func
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_activation
 | 
				
			||||||
 | 
					def identity(input, **kwargs):
 | 
				
			||||||
 | 
					    """:math:`f(x) = x`"""
 | 
				
			||||||
 | 
					    return input
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_activation
 | 
				
			||||||
 | 
					def sigmoid_beta(input, beta=10):
 | 
				
			||||||
 | 
					    """:math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Keyword Arguments:
 | 
				
			||||||
 | 
					        beta (float): Parameter :math:`\\beta`
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    out = torch.reciprocal(1.0 + torch.exp(-beta * input))
 | 
				
			||||||
 | 
					    return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_activation
 | 
				
			||||||
 | 
					def swish_beta(input, beta=10):
 | 
				
			||||||
 | 
					    """:math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Keyword Arguments:
 | 
				
			||||||
 | 
					        beta (float): Parameter :math:`\\beta`
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    out = input * sigmoid_beta(input, beta=beta)
 | 
				
			||||||
 | 
					    return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_activation(funcname):
 | 
				
			||||||
 | 
					    if callable(funcname):
 | 
				
			||||||
 | 
					        return funcname
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        if funcname in ACTIVATIONS:
 | 
				
			||||||
 | 
					            return ACTIVATIONS.get(funcname)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise NameError(f'Activation {funcname} was not found.')
 | 
				
			||||||
							
								
								
									
										15
									
								
								prototorch/functions/competitions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								prototorch/functions/competitions.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch competition functions."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def wtac(distances, labels):
 | 
				
			||||||
 | 
					    winning_indices = torch.min(distances, dim=1).indices
 | 
				
			||||||
 | 
					    winning_labels = labels[winning_indices].squeeze()
 | 
				
			||||||
 | 
					    return winning_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def knnc(distances, labels, k):
 | 
				
			||||||
 | 
					    winning_indices = torch.topk(-distances, k=k, dim=1).indices
 | 
				
			||||||
 | 
					    winning_labels = labels[winning_indices].squeeze()
 | 
				
			||||||
 | 
					    return winning_labels
 | 
				
			||||||
@@ -33,11 +33,14 @@ def lpnorm_distance(x, y, p):
 | 
				
			|||||||
    Expected dimension of x is 2.
 | 
					    Expected dimension of x is 2.
 | 
				
			||||||
    Expected dimension of y is 2.
 | 
					    Expected dimension of y is 2.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    expanded_x = x.unsqueeze(dim=1)
 | 
					    # # DEPRECATED in favor of torch.cdist
 | 
				
			||||||
    batchwise_difference = y - expanded_x
 | 
					    # expanded_x = x.unsqueeze(dim=1)
 | 
				
			||||||
    differences_raised = torch.pow(batchwise_difference, p)
 | 
					    # batchwise_difference = y - expanded_x
 | 
				
			||||||
    distances_raised = torch.sum(differences_raised, axis=2)
 | 
					    # differences_raised = torch.pow(batchwise_difference, p)
 | 
				
			||||||
    distances = torch.pow(distances_raised, 1.0 / p)
 | 
					    # distances_raised = torch.sum(differences_raised, axis=2)
 | 
				
			||||||
 | 
					    # distances = torch.pow(distances_raised, 1.0 / p)
 | 
				
			||||||
 | 
					    # return distances
 | 
				
			||||||
 | 
					    distances = torch.cdist(x, y, p=p)
 | 
				
			||||||
    return distances
 | 
					    return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										93
									
								
								prototorch/functions/initializers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								prototorch/functions/initializers.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,93 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch initialization functions."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from itertools import chain
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					INITIALIZERS = dict()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def register_initializer(func):
 | 
				
			||||||
 | 
					    INITIALIZERS[func.__name__] = func
 | 
				
			||||||
 | 
					    return func
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def labels_from(distribution):
 | 
				
			||||||
 | 
					    """Takes a distribution tensor and returns a labels tensor."""
 | 
				
			||||||
 | 
					    nclasses = distribution.shape[0]
 | 
				
			||||||
 | 
					    llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
 | 
				
			||||||
 | 
					    # labels = [l for cl in llist for l in cl]  # flatten the list of lists
 | 
				
			||||||
 | 
					    labels = list(chain(*llist))  # flatten using itertools.chain
 | 
				
			||||||
 | 
					    return torch.tensor(labels, requires_grad=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_initializer
 | 
				
			||||||
 | 
					def ones(x_train, y_train, prototype_distribution):
 | 
				
			||||||
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
 | 
					    protos = torch.ones(nprotos, *x_train.shape[1:])
 | 
				
			||||||
 | 
					    plabels = labels_from(prototype_distribution)
 | 
				
			||||||
 | 
					    return protos, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_initializer
 | 
				
			||||||
 | 
					def zeros(x_train, y_train, prototype_distribution):
 | 
				
			||||||
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
 | 
					    protos = torch.zeros(nprotos, *x_train.shape[1:])
 | 
				
			||||||
 | 
					    plabels = labels_from(prototype_distribution)
 | 
				
			||||||
 | 
					    return protos, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_initializer
 | 
				
			||||||
 | 
					def rand(x_train, y_train, prototype_distribution):
 | 
				
			||||||
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
 | 
					    protos = torch.rand(nprotos, *x_train.shape[1:])
 | 
				
			||||||
 | 
					    plabels = labels_from(prototype_distribution)
 | 
				
			||||||
 | 
					    return protos, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_initializer
 | 
				
			||||||
 | 
					def randn(x_train, y_train, prototype_distribution):
 | 
				
			||||||
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
 | 
					    protos = torch.randn(nprotos, *x_train.shape[1:])
 | 
				
			||||||
 | 
					    plabels = labels_from(prototype_distribution)
 | 
				
			||||||
 | 
					    return protos, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_initializer
 | 
				
			||||||
 | 
					def stratified_mean(x_train, y_train, prototype_distribution):
 | 
				
			||||||
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
 | 
					    pdim = x_train.shape[1]
 | 
				
			||||||
 | 
					    protos = torch.empty(nprotos, pdim)
 | 
				
			||||||
 | 
					    plabels = labels_from(prototype_distribution)
 | 
				
			||||||
 | 
					    for i, l in enumerate(plabels):
 | 
				
			||||||
 | 
					        xl = x_train[y_train == l]
 | 
				
			||||||
 | 
					        mean_xl = torch.mean(xl, dim=0)
 | 
				
			||||||
 | 
					        protos[i] = mean_xl
 | 
				
			||||||
 | 
					    return protos, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@register_initializer
 | 
				
			||||||
 | 
					def stratified_random(x_train, y_train, prototype_distribution):
 | 
				
			||||||
 | 
					    gen = torch.manual_seed(torch.initial_seed())
 | 
				
			||||||
 | 
					    nprotos = torch.sum(prototype_distribution)
 | 
				
			||||||
 | 
					    pdim = x_train.shape[1]
 | 
				
			||||||
 | 
					    protos = torch.empty(nprotos, pdim)
 | 
				
			||||||
 | 
					    plabels = labels_from(prototype_distribution)
 | 
				
			||||||
 | 
					    for i, l in enumerate(plabels):
 | 
				
			||||||
 | 
					        xl = x_train[y_train == l]
 | 
				
			||||||
 | 
					        rand_index = torch.zeros(1).long().random_(0,
 | 
				
			||||||
 | 
					                                                   xl.shape[1] - 1,
 | 
				
			||||||
 | 
					                                                   generator=gen)
 | 
				
			||||||
 | 
					        random_xl = xl[rand_index]
 | 
				
			||||||
 | 
					        protos[i] = random_xl
 | 
				
			||||||
 | 
					    return protos, plabels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_initializer(funcname):
 | 
				
			||||||
 | 
					    if callable(funcname):
 | 
				
			||||||
 | 
					        return funcname
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        if funcname in INITIALIZERS:
 | 
				
			||||||
 | 
					            return INITIALIZERS.get(funcname)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise NameError(f'Initializer {funcname} was not found.')
 | 
				
			||||||
							
								
								
									
										25
									
								
								prototorch/functions/losses.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								prototorch/functions/losses.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,25 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch loss functions."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def glvq_loss(distances, target_labels, prototype_labels):
 | 
				
			||||||
 | 
					    """GLVQ loss function with support for one-hot labels."""
 | 
				
			||||||
 | 
					    matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
 | 
				
			||||||
 | 
					    if prototype_labels.ndim == 2:
 | 
				
			||||||
 | 
					        # if the labels are one-hot vectors
 | 
				
			||||||
 | 
					        nclasses = target_labels.size()[1]
 | 
				
			||||||
 | 
					        matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
 | 
				
			||||||
 | 
					    not_matcher = torch.bitwise_not(matcher)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    dplus_criterion = distances * matcher > 0.0
 | 
				
			||||||
 | 
					    dminus_criterion = distances * not_matcher > 0.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    inf = torch.full_like(distances, fill_value=float('inf'))
 | 
				
			||||||
 | 
					    distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
 | 
				
			||||||
 | 
					    distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
 | 
				
			||||||
 | 
					    dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
 | 
				
			||||||
 | 
					    dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    mu = (dpluses - dminuses) / (dpluses + dminuses)
 | 
				
			||||||
 | 
					    return mu
 | 
				
			||||||
		Reference in New Issue
	
	Block a user