Codacy Bug Report fixes

This commit is contained in:
Christoph 2021-01-14 10:04:43 +01:00
parent 895281aabd
commit 30dc0ea8b1
5 changed files with 36 additions and 36 deletions

View File

@ -7,10 +7,7 @@ Siamnese fashion
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""
import sys
from torch.nn import parameter
from matplotlib.pyplot import fill
import numpy as np
import torch
import torch.nn as nn

View File

@ -24,7 +24,7 @@ def orthogonalization(tensors):
return out
def trace_normalization(tensors, epsilon=[1e-10]):
def trace_normalization(tensors):
r""" Trace normalization
"""
epsilon = torch.tensor([1e-10], dtype=torch.float64)

View File

@ -3,7 +3,8 @@ import torch
from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
@ -19,6 +20,10 @@ class GTLVQ(nn.Module):
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
prototype data for initalization of the prototypes used in GTLVQ.
subspace_size: int (default=256,optional)
Subspace dimension of the Projectors. Currently only supported
with tagnent_projection_type=global.
tangent_projection_type: string
Specifies the tangent projection type
options: local
@ -82,13 +87,9 @@ class GTLVQ(nn.Module):
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj':
self.subspaces = torch.nn.Parameter(
self.init_local_subspace(
subspace_data).clone().detach().requires_grad_(True))
self.init_local_subspace(subspace_data)
elif self.tpt == 'global':
self.subspaces = torch.nn.Parameter(
self.init_gobal_subspace(
subspace_data).clone().detach().requires_grad_(True))
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
@ -125,13 +126,17 @@ class GTLVQ(nn.Module):
def init_gobal_subspace(self, data, num_subspaces):
_, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
return subspace[:, :num_subspaces]
subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def init_local_subspace(self, data):
_, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
return inital_projector.unsqueeze(0).repeat_interleave(
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0)
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def global_tangent_distances(self, x):
# Tangent Projection
@ -154,13 +159,11 @@ class GTLVQ(nn.Module):
# Origin Author:
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
_, proto_int_shape = _int_and_mixed_shape(protos)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
atom_axes = list(range(3, len(signal_int_shape)))
# Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2])
@ -170,7 +173,7 @@ class GTLVQ(nn.Module):
projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
diss = torch.norm(projected_diff,2,dim=-1)
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
def get_parameters(self):