Remove usage of Prototype1D

Update Iris example to new component API
Update Tecator example to new component API
Update LGMLVQ example to new component API
Update GTLVQ to new component API
This commit is contained in:
Alexander Engelsberger
2021-05-28 15:57:26 +02:00
parent 94fe4435a8
commit 40ef3aeda2
5 changed files with 75 additions and 74 deletions

View File

@@ -1,9 +1,10 @@
import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
from torch import nn
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
@@ -81,13 +82,13 @@ class GTLVQ(nn.Module):
self.feature_dim = feature_dim
self.num_classes = num_classes
self.cls = Prototypes1D(
input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer="stratified_mean",
data=prototype_data,
)
cls_initializer = StratifiedMeanInitializer(prototype_data)
cls_distribution = {
"num_classes": num_classes,
"prototypes_per_class": prototypes_per_class,
}
self.cls = LabeledComponents(cls_distribution, cls_initializer)
if subspace_data is None:
raise ValueError("Init Data must be specified!")
@@ -119,12 +120,12 @@ class GTLVQ(nn.Module):
subspaces = subspace[:, :num_subspaces]
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def init_local_subspace(self, data,num_subspaces,num_protos):
data = data - torch.mean(data,dim=0)
_,_,v = torch.svd(data,some=False)
v = v[:,:num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
self.subspaces = nn.Parameter(subspaces,requires_grad=True)
def init_local_subspace(self, data, num_subspaces, num_protos):
data = data - torch.mean(data, dim=0)
_, _, v = torch.svd(data, some=False)
v = v[:, :num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def global_tangent_distances(self, x):
# Tangent Projection
@@ -138,22 +139,23 @@ class GTLVQ(nn.Module):
def local_tangent_distances(self, x):
# Tangent Distance
x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0),
x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
x.size(-1))
protos = self.cls.prototypes.unsqueeze(0).expand(
x.size(0), self.cls.prototypes.size(0), x.size(-1))
protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
self.cls.num_components,
x.size(-1))
projectors = torch.eye(
self.subspaces.shape[-2], device=x.device) - torch.bmm(
self.subspaces, self.subspaces.permute([0, 2, 1]))
diff = (x - protos)
diff = diff.permute([1, 0, 2])
diff = torch.bmm(diff, projectors)
diff = torch.norm(diff,2,dim=-1).T
diff = torch.norm(diff, 2, dim=-1).T
return diff
def get_parameters(self):
return {
"params": self.cls.prototypes,
"params": self.cls.components,
}, {
"params": self.subspaces
}