Codacy Bug Report fixes

This commit is contained in:
Christoph 2021-01-14 10:04:43 +01:00
parent 895281aabd
commit 30dc0ea8b1
5 changed files with 36 additions and 36 deletions

View File

@ -7,10 +7,7 @@ Siamnese fashion
For more info about GTLVQ see: For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534 DOI:10.1109/IJCNN.2016.7727534
""" """
import sys
from torch.nn import parameter
from matplotlib.pyplot import fill
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn

View File

@ -24,7 +24,7 @@ def orthogonalization(tensors):
return out return out
def trace_normalization(tensors, epsilon=[1e-10]): def trace_normalization(tensors):
r""" Trace normalization r""" Trace normalization
""" """
epsilon = torch.tensor([1e-10], dtype=torch.float64) epsilon = torch.tensor([1e-10], dtype=torch.float64)

View File

@ -3,7 +3,8 @@ import torch
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization from prototorch.functions.normalization import orthogonalization
from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
class GTLVQ(nn.Module): class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization r""" Generalized Tangent Learning Vector Quantization
@ -19,6 +20,10 @@ class GTLVQ(nn.Module):
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional) prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
prototype data for initalization of the prototypes used in GTLVQ. prototype data for initalization of the prototypes used in GTLVQ.
subspace_size: int (default=256,optional)
Subspace dimension of the Projectors. Currently only supported
with tagnent_projection_type=global.
tangent_projection_type: string tangent_projection_type: string
Specifies the tangent projection type Specifies the tangent projection type
options: local options: local
@ -82,13 +87,9 @@ class GTLVQ(nn.Module):
self.tpt = tangent_projection_type self.tpt = tangent_projection_type
with torch.no_grad(): with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj': if self.tpt == 'local' or self.tpt == 'local_proj':
self.subspaces = torch.nn.Parameter( self.init_local_subspace(subspace_data)
self.init_local_subspace(
subspace_data).clone().detach().requires_grad_(True))
elif self.tpt == 'global': elif self.tpt == 'global':
self.subspaces = torch.nn.Parameter( self.init_gobal_subspace(subspace_data, subspace_size)
self.init_gobal_subspace(
subspace_data).clone().detach().requires_grad_(True))
else: else:
self.subspaces = None self.subspaces = None
@ -125,13 +126,17 @@ class GTLVQ(nn.Module):
def init_gobal_subspace(self, data, num_subspaces): def init_gobal_subspace(self, data, num_subspaces):
_, _, v = torch.svd(data) _, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
return subspace[:, :num_subspaces] subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def init_local_subspace(self, data): def init_local_subspace(self, data):
_, _, v = torch.svd(data) _, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
return inital_projector.unsqueeze(0).repeat_interleave( subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0) self.num_protos, 0)
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def global_tangent_distances(self, x): def global_tangent_distances(self, x):
# Tangent Projection # Tangent Projection
@ -154,13 +159,11 @@ class GTLVQ(nn.Module):
# Origin Author: # Origin Author:
signal_shape, signal_int_shape = _int_and_mixed_shape(signals) signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
proto_shape, proto_int_shape = _int_and_mixed_shape(protos) _, proto_int_shape = _int_and_mixed_shape(protos)
# check if the shapes are correct # check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape) _check_shapes(signal_int_shape, proto_int_shape)
atom_axes = list(range(3, len(signal_int_shape)))
# Tangent Data Projections # Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1) projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2]) data = signals.squeeze(2).permute([1, 0, 2])
@ -170,7 +173,7 @@ class GTLVQ(nn.Module):
projected_diff = torch.reshape( projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) + diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:]) signal_shape[3:])
diss = torch.norm(projected_diff,2,dim=-1) diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1) return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
def get_parameters(self): def get_parameters(self):