Codacy Bug Report fixes

This commit is contained in:
Christoph
2021-01-14 10:04:43 +01:00
parent 895281aabd
commit 30dc0ea8b1
5 changed files with 36 additions and 36 deletions

View File

@@ -3,14 +3,15 @@ import torch
from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
Parameters
----------
num_classes: int
num_classes: int
Number of classes of the given classification problem.
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
@@ -19,7 +20,11 @@ class GTLVQ(nn.Module):
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
prototype data for initalization of the prototypes used in GTLVQ.
tangent_projection_type: string
subspace_size: int (default=256,optional)
Subspace dimension of the Projectors. Currently only supported
with tagnent_projection_type=global.
tangent_projection_type: string
Specifies the tangent projection type
options: local
local_proj
@@ -28,33 +33,33 @@ class GTLVQ(nn.Module):
data. Only distances are available
local_proj: computs tangent distances and returns the projected data
for further use. Be careful: data is repeated by number of prototypes
global: Number of subspaces is set to one and every prototypes
uses the same.
global: Number of subspaces is set to one and every prototypes
uses the same.
prototypes_per_class: int (default=2,optional)
Number of prototypes per class
feature_dim: int (default=256)
Dimensionality of the feature space specified as integer.
Prototype dimension.
Dimensionality of the feature space specified as integer.
Prototype dimension.
Notes
-----
The GTLVQ [1] is a prototype-based classification learning model. The
GTLVQ uses the Tangent-Distances for a local point approximation
of an assumed data manifold via prototypial representations.
The GTLVQ [1] is a prototype-based classification learning model. The
GTLVQ uses the Tangent-Distances for a local point approximation
of an assumed data manifold via prototypial representations.
The GTLVQ requires subspace projectors for transforming the data
and prototypes into the affine subspace. Every prototype is
equipped with a specific subpspace and represents a point
and prototypes into the affine subspace. Every prototype is
equipped with a specific subpspace and represents a point
approximation of the assumed manifold.
In practice prototypes and data are projected on this manifold
In practice prototypes and data are projected on this manifold
and pairwise euclidean distance computes.
References
----------
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
in classification based on manifolc. models and its relation
to tangent metric learning. In: 2017 International Joint
Conference on Neural Networks (IJCNN).
@@ -82,13 +87,9 @@ class GTLVQ(nn.Module):
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj':
self.subspaces = torch.nn.Parameter(
self.init_local_subspace(
subspace_data).clone().detach().requires_grad_(True))
self.init_local_subspace(subspace_data)
elif self.tpt == 'global':
self.subspaces = torch.nn.Parameter(
self.init_gobal_subspace(
subspace_data).clone().detach().requires_grad_(True))
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
@@ -125,13 +126,17 @@ class GTLVQ(nn.Module):
def init_gobal_subspace(self, data, num_subspaces):
_, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
return subspace[:, :num_subspaces]
subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def init_local_subspace(self, data):
_, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
return inital_projector.unsqueeze(0).repeat_interleave(
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0)
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def global_tangent_distances(self, x):
# Tangent Projection
@@ -154,13 +159,11 @@ class GTLVQ(nn.Module):
# Origin Author:
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
_, proto_int_shape = _int_and_mixed_shape(protos)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
atom_axes = list(range(3, len(signal_int_shape)))
# Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2])
@@ -170,7 +173,7 @@ class GTLVQ(nn.Module):
projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
diss = torch.norm(projected_diff,2,dim=-1)
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
def get_parameters(self):