Codacy Bug Report fixes
This commit is contained in:
		| @@ -7,10 +7,7 @@ Siamnese fashion | |||||||
| For more info about GTLVQ see: | For more info about GTLVQ see: | ||||||
| DOI:10.1109/IJCNN.2016.7727534 | DOI:10.1109/IJCNN.2016.7727534 | ||||||
| """ | """ | ||||||
| import sys |  | ||||||
|  |  | ||||||
| from torch.nn import parameter |  | ||||||
| from matplotlib.pyplot import fill |  | ||||||
| import numpy as np | import numpy as np | ||||||
| import torch | import torch | ||||||
| import torch.nn as nn | import torch.nn as nn | ||||||
|   | |||||||
| @@ -24,7 +24,7 @@ def orthogonalization(tensors): | |||||||
|     return out |     return out | ||||||
|  |  | ||||||
|  |  | ||||||
| def trace_normalization(tensors, epsilon=[1e-10]): | def trace_normalization(tensors): | ||||||
|     r""" Trace normalization |     r""" Trace normalization | ||||||
|     """ |     """ | ||||||
|     epsilon = torch.tensor([1e-10], dtype=torch.float64) |     epsilon = torch.tensor([1e-10], dtype=torch.float64) | ||||||
|   | |||||||
| @@ -3,7 +3,8 @@ import torch | |||||||
| from prototorch.modules.prototypes import Prototypes1D | from prototorch.modules.prototypes import Prototypes1D | ||||||
| from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix | from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix | ||||||
| from prototorch.functions.normalization import orthogonalization | from prototorch.functions.normalization import orthogonalization | ||||||
| from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape | from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape | ||||||
|  |  | ||||||
|  |  | ||||||
| class GTLVQ(nn.Module): | class GTLVQ(nn.Module): | ||||||
|     r""" Generalized Tangent Learning Vector Quantization |     r""" Generalized Tangent Learning Vector Quantization | ||||||
| @@ -19,6 +20,10 @@ class GTLVQ(nn.Module): | |||||||
|     prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional) |     prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional) | ||||||
|         prototype data for initalization of the prototypes used in GTLVQ. |         prototype data for initalization of the prototypes used in GTLVQ. | ||||||
|  |  | ||||||
|  |     subspace_size: int (default=256,optional) | ||||||
|  |         Subspace dimension of the Projectors. Currently only supported | ||||||
|  |         with tagnent_projection_type=global. | ||||||
|  |  | ||||||
|     tangent_projection_type: string |     tangent_projection_type: string | ||||||
|         Specifies the tangent projection type |         Specifies the tangent projection type | ||||||
|         options:    local |         options:    local | ||||||
| @@ -82,13 +87,9 @@ class GTLVQ(nn.Module): | |||||||
|         self.tpt = tangent_projection_type |         self.tpt = tangent_projection_type | ||||||
|         with torch.no_grad(): |         with torch.no_grad(): | ||||||
|             if self.tpt == 'local' or self.tpt == 'local_proj': |             if self.tpt == 'local' or self.tpt == 'local_proj': | ||||||
|                 self.subspaces = torch.nn.Parameter( |                 self.init_local_subspace(subspace_data) | ||||||
|                     self.init_local_subspace( |  | ||||||
|                         subspace_data).clone().detach().requires_grad_(True)) |  | ||||||
|             elif self.tpt == 'global': |             elif self.tpt == 'global': | ||||||
|                 self.subspaces = torch.nn.Parameter( |                 self.init_gobal_subspace(subspace_data, subspace_size) | ||||||
|                     self.init_gobal_subspace( |  | ||||||
|                         subspace_data).clone().detach().requires_grad_(True)) |  | ||||||
|             else: |             else: | ||||||
|                 self.subspaces = None |                 self.subspaces = None | ||||||
|  |  | ||||||
| @@ -125,13 +126,17 @@ class GTLVQ(nn.Module): | |||||||
|     def init_gobal_subspace(self, data, num_subspaces): |     def init_gobal_subspace(self, data, num_subspaces): | ||||||
|         _, _, v = torch.svd(data) |         _, _, v = torch.svd(data) | ||||||
|         subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T |         subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T | ||||||
|         return subspace[:, :num_subspaces] |         subspaces = subspace[:, :num_subspaces] | ||||||
|  |         self.subspaces = torch.nn.Parameter( | ||||||
|  |             subspaces).clone().detach().requires_grad_(True) | ||||||
|  |  | ||||||
|     def init_local_subspace(self, data): |     def init_local_subspace(self, data): | ||||||
|         _, _, v = torch.svd(data) |         _, _, v = torch.svd(data) | ||||||
|         inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T |         inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T | ||||||
|         return inital_projector.unsqueeze(0).repeat_interleave( |         subspaces = inital_projector.unsqueeze(0).repeat_interleave( | ||||||
|             self.num_protos, 0) |             self.num_protos, 0) | ||||||
|  |         self.subspaces = torch.nn.Parameter( | ||||||
|  |             subspaces).clone().detach().requires_grad_(True) | ||||||
|  |  | ||||||
|     def global_tangent_distances(self, x): |     def global_tangent_distances(self, x): | ||||||
|         # Tangent Projection |         # Tangent Projection | ||||||
| @@ -154,13 +159,11 @@ class GTLVQ(nn.Module): | |||||||
|         # Origin Author: |         # Origin Author: | ||||||
|  |  | ||||||
|         signal_shape, signal_int_shape = _int_and_mixed_shape(signals) |         signal_shape, signal_int_shape = _int_and_mixed_shape(signals) | ||||||
|         proto_shape, proto_int_shape = _int_and_mixed_shape(protos) |         _, proto_int_shape = _int_and_mixed_shape(protos) | ||||||
|  |  | ||||||
|         # check if the shapes are correct |         # check if the shapes are correct | ||||||
|         _check_shapes(signal_int_shape, proto_int_shape) |         _check_shapes(signal_int_shape, proto_int_shape) | ||||||
|  |  | ||||||
|         atom_axes = list(range(3, len(signal_int_shape))) |  | ||||||
|  |  | ||||||
|         # Tangent Data Projections |         # Tangent Data Projections | ||||||
|         projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1) |         projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1) | ||||||
|         data = signals.squeeze(2).permute([1, 0, 2]) |         data = signals.squeeze(2).permute([1, 0, 2]) | ||||||
| @@ -170,7 +173,7 @@ class GTLVQ(nn.Module): | |||||||
|         projected_diff = torch.reshape( |         projected_diff = torch.reshape( | ||||||
|             diff, (signal_shape[1], signal_shape[0], signal_shape[2]) + |             diff, (signal_shape[1], signal_shape[0], signal_shape[2]) + | ||||||
|             signal_shape[3:]) |             signal_shape[3:]) | ||||||
|         diss = torch.norm(projected_diff,2,dim=-1) |         diss = torch.norm(projected_diff, 2, dim=-1) | ||||||
|         return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1) |         return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1) | ||||||
|  |  | ||||||
|     def get_parameters(self): |     def get_parameters(self): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user