integrate reviews from ChristophRaab:master

This commit is contained in:
Alexander Engelsberger 2021-05-27 09:43:02 +02:00
parent 00615ae837
commit c204bc8e1f

View File

@ -1,12 +1,9 @@
import torch
from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
from torch import nn
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
@ -122,12 +119,12 @@ class GTLVQ(nn.Module):
subspaces = subspace[:, :num_subspaces]
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def init_local_subspace(self, data, num_subspaces, num_protos):
data = data - torch.mean(data, dim=0)
_, _, v = torch.svd(data, some=False)
v = v[:, :num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def init_local_subspace(self, data,num_subspaces,num_protos):
data = data - torch.mean(data,dim=0)
_,_,v = torch.svd(data,some=False)
v = v[:,:num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
self.subspaces = nn.Parameter(subspaces,requires_grad=True)
def global_tangent_distances(self, x):
# Tangent Projection
@ -151,7 +148,7 @@ class GTLVQ(nn.Module):
diff = (x - protos)
diff = diff.permute([1, 0, 2])
diff = torch.bmm(diff, projectors)
diff = torch.norm(diff, 2, dim=-1).T
diff = torch.norm(diff,2,dim=-1).T
return diff
def get_parameters(self):