integrate reviews from ChristophRaab:master
This commit is contained in:
parent
00615ae837
commit
c204bc8e1f
@ -1,12 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
from prototorch.functions.distances import (euclidean_distance_matrix,
|
from prototorch.functions.distances import euclidean_distance_matrix
|
||||||
tangent_distance)
|
|
||||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
|
||||||
from prototorch.functions.normalization import orthogonalization
|
from prototorch.functions.normalization import orthogonalization
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
class GTLVQ(nn.Module):
|
class GTLVQ(nn.Module):
|
||||||
r""" Generalized Tangent Learning Vector Quantization
|
r""" Generalized Tangent Learning Vector Quantization
|
||||||
|
|
||||||
@ -122,12 +119,12 @@ class GTLVQ(nn.Module):
|
|||||||
subspaces = subspace[:, :num_subspaces]
|
subspaces = subspace[:, :num_subspaces]
|
||||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||||
|
|
||||||
def init_local_subspace(self, data, num_subspaces, num_protos):
|
def init_local_subspace(self, data,num_subspaces,num_protos):
|
||||||
data = data - torch.mean(data, dim=0)
|
data = data - torch.mean(data,dim=0)
|
||||||
_, _, v = torch.svd(data, some=False)
|
_,_,v = torch.svd(data,some=False)
|
||||||
v = v[:, :num_subspaces]
|
v = v[:,:num_subspaces]
|
||||||
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
|
subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
|
||||||
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
self.subspaces = nn.Parameter(subspaces,requires_grad=True)
|
||||||
|
|
||||||
def global_tangent_distances(self, x):
|
def global_tangent_distances(self, x):
|
||||||
# Tangent Projection
|
# Tangent Projection
|
||||||
@ -151,7 +148,7 @@ class GTLVQ(nn.Module):
|
|||||||
diff = (x - protos)
|
diff = (x - protos)
|
||||||
diff = diff.permute([1, 0, 2])
|
diff = diff.permute([1, 0, 2])
|
||||||
diff = torch.bmm(diff, projectors)
|
diff = torch.bmm(diff, projectors)
|
||||||
diff = torch.norm(diff, 2, dim=-1).T
|
diff = torch.norm(diff,2,dim=-1).T
|
||||||
return diff
|
return diff
|
||||||
|
|
||||||
def get_parameters(self):
|
def get_parameters(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user