integrate reviews from ChristophRaab:master
This commit is contained in:
parent
00615ae837
commit
c204bc8e1f
@ -1,12 +1,9 @@
|
||||
import torch
|
||||
from prototorch.functions.distances import (euclidean_distance_matrix,
|
||||
tangent_distance)
|
||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
||||
from prototorch.functions.distances import euclidean_distance_matrix
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GTLVQ(nn.Module):
|
||||
r""" Generalized Tangent Learning Vector Quantization
|
||||
|
||||
@ -122,12 +119,12 @@ class GTLVQ(nn.Module):
|
||||
subspaces = subspace[:, :num_subspaces]
|
||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||
|
||||
def init_local_subspace(self, data, num_subspaces, num_protos):
|
||||
data = data - torch.mean(data, dim=0)
|
||||
_, _, v = torch.svd(data, some=False)
|
||||
v = v[:, :num_subspaces]
|
||||
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
|
||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||
def init_local_subspace(self, data,num_subspaces,num_protos):
|
||||
data = data - torch.mean(data,dim=0)
|
||||
_,_,v = torch.svd(data,some=False)
|
||||
v = v[:,:num_subspaces]
|
||||
subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
|
||||
self.subspaces = nn.Parameter(subspaces,requires_grad=True)
|
||||
|
||||
def global_tangent_distances(self, x):
|
||||
# Tangent Projection
|
||||
@ -151,7 +148,7 @@ class GTLVQ(nn.Module):
|
||||
diff = (x - protos)
|
||||
diff = diff.permute([1, 0, 2])
|
||||
diff = torch.bmm(diff, projectors)
|
||||
diff = torch.norm(diff, 2, dim=-1).T
|
||||
diff = torch.norm(diff,2,dim=-1).T
|
||||
return diff
|
||||
|
||||
def get_parameters(self):
|
||||
|
Loading…
Reference in New Issue
Block a user