Codacy Bug Report fixes

This commit is contained in:
Christoph 2021-01-14 10:04:43 +01:00
parent 895281aabd
commit 30dc0ea8b1
5 changed files with 36 additions and 36 deletions

View File

@ -1,16 +1,13 @@
"""
ProtoTorch GTLVQ example using MNIST data.
The GTLVQ is placed as an classification model on
The GTLVQ is placed as an classification model on
top of a CNN, considered as featurer extractor.
Initialization of subpsace and prototypes in
Initialization of subpsace and prototypes in
Siamnese fashion
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""
import sys
from torch.nn import parameter
from matplotlib.pyplot import fill
import numpy as np
import torch
import torch.nn as nn

View File

@ -77,10 +77,10 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
r""" Computes an euclidean distanes matrix given two distinct vectors.
last dimension must be the vector dimension!
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
x.shape = (number_of_x_vectors, vector_dim)
y.shape = (number_of_y_vectors, vector_dim)
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
"""
for tensor in [x, y]:

View File

@ -15,7 +15,7 @@ def calculate_prototype_accuracy(y_pred, y_true, plabels):
def predict_label(y_pred, plabels):
r""" Predicts labels given a prediction of a prototype based model.
r""" Predicts labels given a prediction of a prototype based model.
"""
with torch.no_grad():
return plabels[torch.argmin(y_pred, 1)]

View File

@ -24,7 +24,7 @@ def orthogonalization(tensors):
return out
def trace_normalization(tensors, epsilon=[1e-10]):
def trace_normalization(tensors):
r""" Trace normalization
"""
epsilon = torch.tensor([1e-10], dtype=torch.float64)

View File

@ -3,14 +3,15 @@ import torch
from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
Parameters
----------
num_classes: int
num_classes: int
Number of classes of the given classification problem.
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
@ -19,7 +20,11 @@ class GTLVQ(nn.Module):
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
prototype data for initalization of the prototypes used in GTLVQ.
tangent_projection_type: string
subspace_size: int (default=256,optional)
Subspace dimension of the Projectors. Currently only supported
with tagnent_projection_type=global.
tangent_projection_type: string
Specifies the tangent projection type
options: local
local_proj
@ -28,33 +33,33 @@ class GTLVQ(nn.Module):
data. Only distances are available
local_proj: computs tangent distances and returns the projected data
for further use. Be careful: data is repeated by number of prototypes
global: Number of subspaces is set to one and every prototypes
uses the same.
global: Number of subspaces is set to one and every prototypes
uses the same.
prototypes_per_class: int (default=2,optional)
Number of prototypes per class
feature_dim: int (default=256)
Dimensionality of the feature space specified as integer.
Prototype dimension.
Dimensionality of the feature space specified as integer.
Prototype dimension.
Notes
-----
The GTLVQ [1] is a prototype-based classification learning model. The
GTLVQ uses the Tangent-Distances for a local point approximation
of an assumed data manifold via prototypial representations.
The GTLVQ [1] is a prototype-based classification learning model. The
GTLVQ uses the Tangent-Distances for a local point approximation
of an assumed data manifold via prototypial representations.
The GTLVQ requires subspace projectors for transforming the data
and prototypes into the affine subspace. Every prototype is
equipped with a specific subpspace and represents a point
and prototypes into the affine subspace. Every prototype is
equipped with a specific subpspace and represents a point
approximation of the assumed manifold.
In practice prototypes and data are projected on this manifold
In practice prototypes and data are projected on this manifold
and pairwise euclidean distance computes.
References
----------
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
in classification based on manifolc. models and its relation
to tangent metric learning. In: 2017 International Joint
Conference on Neural Networks (IJCNN).
@ -82,13 +87,9 @@ class GTLVQ(nn.Module):
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj':
self.subspaces = torch.nn.Parameter(
self.init_local_subspace(
subspace_data).clone().detach().requires_grad_(True))
self.init_local_subspace(subspace_data)
elif self.tpt == 'global':
self.subspaces = torch.nn.Parameter(
self.init_gobal_subspace(
subspace_data).clone().detach().requires_grad_(True))
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
@ -125,13 +126,17 @@ class GTLVQ(nn.Module):
def init_gobal_subspace(self, data, num_subspaces):
_, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
return subspace[:, :num_subspaces]
subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def init_local_subspace(self, data):
_, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
return inital_projector.unsqueeze(0).repeat_interleave(
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0)
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
def global_tangent_distances(self, x):
# Tangent Projection
@ -154,13 +159,11 @@ class GTLVQ(nn.Module):
# Origin Author:
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
_, proto_int_shape = _int_and_mixed_shape(protos)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
atom_axes = list(range(3, len(signal_int_shape)))
# Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2])
@ -170,7 +173,7 @@ class GTLVQ(nn.Module):
projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
diss = torch.norm(projected_diff,2,dim=-1)
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
def get_parameters(self):