refactored gtlvq from ChristophRaab:master

Alexander Engelsberger 2021-05-27 09:39:59 +02:00
parent 9f5f0d12dd
commit 00615ae837


@@ -79,45 +79,35 @@ class GTLVQ(nn.Module):
         super(GTLVQ, self).__init__()
         self.num_protos = num_classes * prototypes_per_class
+        self.num_protos_class = prototypes_per_class
         self.subspace_size = feature_dim if subspace_size is None else subspace_size
         self.feature_dim = feature_dim
+        self.num_classes = num_classes
-        self.cls = Prototypes1D(
-            input_dim=feature_dim,
-            prototypes_per_class=prototypes_per_class,
-            nclasses=num_classes,
-            prototype_initializer="stratified_mean",
-            data=prototype_data,
-        )
         if subspace_data is None:
             raise ValueError("Init Data must be specified!")
         self.tpt = tangent_projection_type
         with torch.no_grad():
-            if self.tpt == "local" or self.tpt == "local_proj":
-                self.init_local_subspace(subspace_data)
+            if self.tpt == "local":
+                self.init_local_subspace(subspace_data, subspace_size,
+                                         self.num_protos)
             elif self.tpt == "global":
                 self.init_gobal_subspace(subspace_data, subspace_size)
             else:
                 self.subspaces = None
+        # Hypothesis-Margin-Classifier
+        self.cls = Prototypes1D(
+            input_dim=feature_dim,
+            prototypes_per_class=prototypes_per_class,
+            num_classes=num_classes,
+            prototype_initializer="stratified_mean",
+            data=prototype_data,
+        )

     def forward(self, x):
-        # Tangent Projection
-        if self.tpt == "local_proj":
-            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
-                                                          1).unsqueeze(2))
-            dis, proj_x = self.local_tangent_projection(x_conform)
-            proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
-                                    self.feature_dim)
-            return proj_x, dis
-        elif self.tpt == "local":
-            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
-                                                          1).unsqueeze(2))
-            dis = tangent_distance(x_conform, self.cls.prototypes,
-                                   self.subspaces)
+        if self.tpt == "local":
+            dis = self.local_tangent_distances(x)
         elif self.tpt == "gloabl":
             dis = self.global_tangent_distances(x)
         else:
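Note: the rewritten forward replaces the old repeat_interleave copy with an expand view inside local_tangent_distances. A minimal standalone sketch (not part of the commit) showing that both produce the same per-prototype broadcast of a batch of flat feature vectors:

import torch

batch, num_protos, dim = 4, 6, 8
x = torch.randn(batch, dim)

# Old style: repeat_interleave materializes num_protos copies of every sample.
old = x.unsqueeze(1).repeat_interleave(num_protos, 1)  # (batch, num_protos, dim)

# New style: expand returns a zero-copy broadcast view with the same values.
new = x.unsqueeze(1).expand(batch, num_protos, dim)

assert torch.equal(old, new)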
@@ -130,16 +120,14 @@ class GTLVQ(nn.Module):
         _, _, v = torch.svd(data)
         subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
         subspaces = subspace[:, :num_subspaces]
-        self.subspaces = (torch.nn.Parameter(
-            subspaces).clone().detach().requires_grad_(True))
+        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

-    def init_local_subspace(self, data):
-        _, _, v = torch.svd(data)
-        inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
-        subspaces = inital_projector.unsqueeze(0).repeat_interleave(
-            self.num_protos, 0)
-        self.subspaces = (torch.nn.Parameter(
-            subspaces).clone().detach().requires_grad_(True))
+    def init_local_subspace(self, data, num_subspaces, num_protos):
+        data = data - torch.mean(data, dim=0)
+        _, _, v = torch.svd(data, some=False)
+        v = v[:, :num_subspaces]
+        subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
+        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

     def global_tangent_distances(self, x):
         # Tangent Projection
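Note: the new init_local_subspace no longer stores a projector onto the orthogonal complement; it centers the data and keeps the leading right singular vectors as an explicit basis, tiled once per prototype. A standalone sketch of that construction (the helper name is illustrative, not from the commit):

import torch

def make_local_subspaces(data, num_subspaces, num_protos):
    # Center the data so the SVD captures directions of variation.
    data = data - data.mean(dim=0)
    # Full SVD; v holds the right singular vectors as columns.
    _, _, v = torch.svd(data, some=False)
    # Keep the leading num_subspaces directions as the tangent basis.
    basis = v[:, :num_subspaces]  # (dim, num_subspaces)
    # One (initially identical) basis per prototype, learnable afterwards.
    return basis.unsqueeze(0).repeat_interleave(num_protos, 0)

subspaces = make_local_subspaces(torch.randn(100, 16), num_subspaces=4, num_protos=6)
print(subspaces.shape)  # torch.Size([6, 16, 4])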
@@ -150,33 +138,21 @@ class GTLVQ(nn.Module):
         # Euclidean Distance
         return euclidean_distance_matrix(x, projected_prototypes)

-    def local_tangent_projection(self, signals):
-        # Note: subspaces is always assumed as transposed and must be orthogonal!
-        # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
-        # shape(protos): proto_number x dim1 x dim2 x ... x dimN
-        # shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
-        # subspace should be orthogonalized
-        # Origin Source Code
-        # Origin Author:
-        protos = self.cls.prototypes
-        subspaces = self.subspaces
-        signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
-        _, proto_int_shape = _int_and_mixed_shape(protos)
-
-        # check if the shapes are correct
-        _check_shapes(signal_int_shape, proto_int_shape)
-
-        # Tangent Data Projections
-        projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
-        data = signals.squeeze(2).permute([1, 0, 2])
-        projected_data = torch.bmm(data, subspaces)
-        projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
-        diff = projected_data - projected_protos
-        projected_diff = torch.reshape(
-            diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
-            signal_shape[3:])
-        diss = torch.norm(projected_diff, 2, dim=-1)
-        return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
+    def local_tangent_distances(self, x):
+        # Tangent Distance
+        x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0),
+                                  x.size(-1))
+        protos = self.cls.prototypes.unsqueeze(0).expand(
+            x.size(0), self.cls.prototypes.size(0), x.size(-1))
+        projectors = torch.eye(
+            self.subspaces.shape[-2], device=x.device) - torch.bmm(
+                self.subspaces, self.subspaces.permute([0, 2, 1]))
+        diff = (x - protos)
+        diff = diff.permute([1, 0, 2])
+        diff = torch.bmm(diff, projectors)
+        diff = torch.norm(diff, 2, dim=-1).T
+        return diff

     def get_parameters(self):
         return {
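Note: the distance the new local_tangent_distances computes is the norm of the prototype difference after projecting out each local subspace, i.e. ||(I - V Vᵀ)(x - w)|| per prototype. A standalone sketch under the same shape assumptions (subspaces of shape (num_protos, dim, subspace_size) with orthonormal columns; the function name mirrors the commit but this is an illustration, not the module's code):

import torch

def local_tangent_distances(x, protos, subspaces):
    # x: (batch, dim), protos: (num_protos, dim),
    # subspaces: (num_protos, dim, subspace_size), orthonormal columns.
    num_protos, dim, _ = subspaces.shape
    # Projector onto the orthogonal complement of each local subspace.
    projectors = torch.eye(dim) - torch.bmm(subspaces,
                                            subspaces.permute(0, 2, 1))
    diff = x.unsqueeze(1) - protos.unsqueeze(0)  # (batch, num_protos, dim)
    diff = diff.permute(1, 0, 2)                 # (num_protos, batch, dim)
    diff = torch.bmm(diff, projectors)           # project out the subspace
    return torch.norm(diff, 2, dim=-1).T         # (batch, num_protos)

dists = local_tangent_distances(torch.randn(4, 16),
                                torch.randn(6, 16),
                                torch.linalg.qr(torch.randn(6, 16, 3)).Q)
print(dists.shape)  # torch.Size([4, 6])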