integrate reviews from ChristophRaab:master

This commit is contained in:
Alexander Engelsberger 2021-05-27 09:43:02 +02:00
parent 00615ae837
commit c204bc8e1f

View File

@ -1,12 +1,9 @@
import torch import torch
from prototorch.functions.distances import (euclidean_distance_matrix, from prototorch.functions.distances import euclidean_distance_matrix
tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
from torch import nn from torch import nn
class GTLVQ(nn.Module): class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization r""" Generalized Tangent Learning Vector Quantization
@ -122,12 +119,12 @@ class GTLVQ(nn.Module):
subspaces = subspace[:, :num_subspaces] subspaces = subspace[:, :num_subspaces]
self.subspaces = nn.Parameter(subspaces, requires_grad=True) self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def init_local_subspace(self, data, num_subspaces, num_protos): def init_local_subspace(self, data,num_subspaces,num_protos):
data = data - torch.mean(data, dim=0) data = data - torch.mean(data,dim=0)
_, _, v = torch.svd(data, some=False) _,_,v = torch.svd(data,some=False)
v = v[:, :num_subspaces] v = v[:,:num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0) subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
self.subspaces = nn.Parameter(subspaces, requires_grad=True) self.subspaces = nn.Parameter(subspaces,requires_grad=True)
def global_tangent_distances(self, x): def global_tangent_distances(self, x):
# Tangent Projection # Tangent Projection
@ -151,7 +148,7 @@ class GTLVQ(nn.Module):
diff = (x - protos) diff = (x - protos)
diff = diff.permute([1, 0, 2]) diff = diff.permute([1, 0, 2])
diff = torch.bmm(diff, projectors) diff = torch.bmm(diff, projectors)
diff = torch.norm(diff, 2, dim=-1).T diff = torch.norm(diff,2,dim=-1).T
return diff return diff
def get_parameters(self): def get_parameters(self):