refactored gtlvq from ChristophRaab:master
This commit is contained in:
parent
9f5f0d12dd
commit
00615ae837
@ -79,45 +79,35 @@ class GTLVQ(nn.Module):
|
|||||||
super(GTLVQ, self).__init__()
|
super(GTLVQ, self).__init__()
|
||||||
|
|
||||||
self.num_protos = num_classes * prototypes_per_class
|
self.num_protos = num_classes * prototypes_per_class
|
||||||
|
self.num_protos_class = prototypes_per_class
|
||||||
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
||||||
self.feature_dim = feature_dim
|
self.feature_dim = feature_dim
|
||||||
|
self.num_classes = num_classes
|
||||||
|
|
||||||
|
self.cls = Prototypes1D(
|
||||||
|
input_dim=feature_dim,
|
||||||
|
prototypes_per_class=prototypes_per_class,
|
||||||
|
nclasses=num_classes,
|
||||||
|
prototype_initializer="stratified_mean",
|
||||||
|
data=prototype_data,
|
||||||
|
)
|
||||||
|
|
||||||
if subspace_data is None:
|
if subspace_data is None:
|
||||||
raise ValueError("Init Data must be specified!")
|
raise ValueError("Init Data must be specified!")
|
||||||
|
|
||||||
self.tpt = tangent_projection_type
|
self.tpt = tangent_projection_type
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.tpt == "local" or self.tpt == "local_proj":
|
if self.tpt == "local":
|
||||||
self.init_local_subspace(subspace_data)
|
self.init_local_subspace(subspace_data, subspace_size,
|
||||||
|
self.num_protos)
|
||||||
elif self.tpt == "global":
|
elif self.tpt == "global":
|
||||||
self.init_gobal_subspace(subspace_data, subspace_size)
|
self.init_gobal_subspace(subspace_data, subspace_size)
|
||||||
else:
|
else:
|
||||||
self.subspaces = None
|
self.subspaces = None
|
||||||
|
|
||||||
# Hypothesis-Margin-Classifier
|
|
||||||
self.cls = Prototypes1D(
|
|
||||||
input_dim=feature_dim,
|
|
||||||
prototypes_per_class=prototypes_per_class,
|
|
||||||
num_classes=num_classes,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=prototype_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Tangent Projection
|
if self.tpt == "local":
|
||||||
if self.tpt == "local_proj":
|
dis = self.local_tangent_distances(x)
|
||||||
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
|
|
||||||
1).unsqueeze(2))
|
|
||||||
dis, proj_x = self.local_tangent_projection(x_conform)
|
|
||||||
|
|
||||||
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
|
|
||||||
self.feature_dim)
|
|
||||||
return proj_x, dis
|
|
||||||
elif self.tpt == "local":
|
|
||||||
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
|
|
||||||
1).unsqueeze(2))
|
|
||||||
dis = tangent_distance(x_conform, self.cls.prototypes,
|
|
||||||
self.subspaces)
|
|
||||||
elif self.tpt == "gloabl":
|
elif self.tpt == "gloabl":
|
||||||
dis = self.global_tangent_distances(x)
|
dis = self.global_tangent_distances(x)
|
||||||
else:
|
else:
|
||||||
@ -130,16 +120,14 @@ class GTLVQ(nn.Module):
|
|||||||
_, _, v = torch.svd(data)
|
_, _, v = torch.svd(data)
|
||||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||||
subspaces = subspace[:, :num_subspaces]
|
subspaces = subspace[:, :num_subspaces]
|
||||||
self.subspaces = (torch.nn.Parameter(
|
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||||
subspaces).clone().detach().requires_grad_(True))
|
|
||||||
|
|
||||||
def init_local_subspace(self, data):
|
def init_local_subspace(self, data, num_subspaces, num_protos):
|
||||||
_, _, v = torch.svd(data)
|
data = data - torch.mean(data, dim=0)
|
||||||
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
_, _, v = torch.svd(data, some=False)
|
||||||
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
|
v = v[:, :num_subspaces]
|
||||||
self.num_protos, 0)
|
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
|
||||||
self.subspaces = (torch.nn.Parameter(
|
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
|
||||||
subspaces).clone().detach().requires_grad_(True))
|
|
||||||
|
|
||||||
def global_tangent_distances(self, x):
|
def global_tangent_distances(self, x):
|
||||||
# Tangent Projection
|
# Tangent Projection
|
||||||
@ -150,33 +138,21 @@ class GTLVQ(nn.Module):
|
|||||||
# Euclidean Distance
|
# Euclidean Distance
|
||||||
return euclidean_distance_matrix(x, projected_prototypes)
|
return euclidean_distance_matrix(x, projected_prototypes)
|
||||||
|
|
||||||
def local_tangent_projection(self, signals):
|
def local_tangent_distances(self, x):
|
||||||
# Note: subspaces is always assumed as transposed and must be orthogonal!
|
|
||||||
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
|
||||||
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
|
||||||
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
|
||||||
# subspace should be orthogonalized
|
|
||||||
# Origin Source Code
|
|
||||||
# Origin Author:
|
|
||||||
protos = self.cls.prototypes
|
|
||||||
subspaces = self.subspaces
|
|
||||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
|
||||||
_, proto_int_shape = _int_and_mixed_shape(protos)
|
|
||||||
|
|
||||||
# check if the shapes are correct
|
# Tangent Distance
|
||||||
_check_shapes(signal_int_shape, proto_int_shape)
|
x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0),
|
||||||
|
x.size(-1))
|
||||||
# Tangent Data Projections
|
protos = self.cls.prototypes.unsqueeze(0).expand(
|
||||||
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
|
x.size(0), self.cls.prototypes.size(0), x.size(-1))
|
||||||
data = signals.squeeze(2).permute([1, 0, 2])
|
projectors = torch.eye(
|
||||||
projected_data = torch.bmm(data, subspaces)
|
self.subspaces.shape[-2], device=x.device) - torch.bmm(
|
||||||
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
|
self.subspaces, self.subspaces.permute([0, 2, 1]))
|
||||||
diff = projected_data - projected_protos
|
diff = (x - protos)
|
||||||
projected_diff = torch.reshape(
|
diff = diff.permute([1, 0, 2])
|
||||||
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
|
diff = torch.bmm(diff, projectors)
|
||||||
signal_shape[3:])
|
diff = torch.norm(diff, 2, dim=-1).T
|
||||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
return diff
|
||||||
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
|
|
||||||
|
|
||||||
def get_parameters(self):
|
def get_parameters(self):
|
||||||
return {
|
return {
|
||||||
|
Loading…
Reference in New Issue
Block a user