Codacy Bug Report fixes
This commit is contained in:
parent
895281aabd
commit
30dc0ea8b1
@ -1,16 +1,13 @@
|
|||||||
"""
|
"""
|
||||||
ProtoTorch GTLVQ example using MNIST data.
|
ProtoTorch GTLVQ example using MNIST data.
|
||||||
The GTLVQ is placed as an classification model on
|
The GTLVQ is placed as an classification model on
|
||||||
top of a CNN, considered as featurer extractor.
|
top of a CNN, considered as featurer extractor.
|
||||||
Initialization of subpsace and prototypes in
|
Initialization of subpsace and prototypes in
|
||||||
Siamnese fashion
|
Siamnese fashion
|
||||||
For more info about GTLVQ see:
|
For more info about GTLVQ see:
|
||||||
DOI:10.1109/IJCNN.2016.7727534
|
DOI:10.1109/IJCNN.2016.7727534
|
||||||
"""
|
"""
|
||||||
import sys
|
|
||||||
|
|
||||||
from torch.nn import parameter
|
|
||||||
from matplotlib.pyplot import fill
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -77,10 +77,10 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
|
|||||||
r""" Computes an euclidean distanes matrix given two distinct vectors.
|
r""" Computes an euclidean distanes matrix given two distinct vectors.
|
||||||
last dimension must be the vector dimension!
|
last dimension must be the vector dimension!
|
||||||
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
|
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
|
||||||
|
|
||||||
x.shape = (number_of_x_vectors, vector_dim)
|
x.shape = (number_of_x_vectors, vector_dim)
|
||||||
y.shape = (number_of_y_vectors, vector_dim)
|
y.shape = (number_of_y_vectors, vector_dim)
|
||||||
|
|
||||||
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
|
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
|
||||||
"""
|
"""
|
||||||
for tensor in [x, y]:
|
for tensor in [x, y]:
|
||||||
|
@ -15,7 +15,7 @@ def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
|||||||
|
|
||||||
|
|
||||||
def predict_label(y_pred, plabels):
|
def predict_label(y_pred, plabels):
|
||||||
r""" Predicts labels given a prediction of a prototype based model.
|
r""" Predicts labels given a prediction of a prototype based model.
|
||||||
"""
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
return plabels[torch.argmin(y_pred, 1)]
|
return plabels[torch.argmin(y_pred, 1)]
|
||||||
|
@ -24,7 +24,7 @@ def orthogonalization(tensors):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def trace_normalization(tensors, epsilon=[1e-10]):
|
def trace_normalization(tensors):
|
||||||
r""" Trace normalization
|
r""" Trace normalization
|
||||||
"""
|
"""
|
||||||
epsilon = torch.tensor([1e-10], dtype=torch.float64)
|
epsilon = torch.tensor([1e-10], dtype=torch.float64)
|
||||||
|
@ -3,14 +3,15 @@ import torch
|
|||||||
from prototorch.modules.prototypes import Prototypes1D
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
|
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
|
||||||
from prototorch.functions.normalization import orthogonalization
|
from prototorch.functions.normalization import orthogonalization
|
||||||
from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape
|
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
||||||
|
|
||||||
|
|
||||||
class GTLVQ(nn.Module):
|
class GTLVQ(nn.Module):
|
||||||
r""" Generalized Tangent Learning Vector Quantization
|
r""" Generalized Tangent Learning Vector Quantization
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
num_classes: int
|
num_classes: int
|
||||||
Number of classes of the given classification problem.
|
Number of classes of the given classification problem.
|
||||||
|
|
||||||
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
|
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
|
||||||
@ -19,7 +20,11 @@ class GTLVQ(nn.Module):
|
|||||||
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
|
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
|
||||||
prototype data for initalization of the prototypes used in GTLVQ.
|
prototype data for initalization of the prototypes used in GTLVQ.
|
||||||
|
|
||||||
tangent_projection_type: string
|
subspace_size: int (default=256,optional)
|
||||||
|
Subspace dimension of the Projectors. Currently only supported
|
||||||
|
with tagnent_projection_type=global.
|
||||||
|
|
||||||
|
tangent_projection_type: string
|
||||||
Specifies the tangent projection type
|
Specifies the tangent projection type
|
||||||
options: local
|
options: local
|
||||||
local_proj
|
local_proj
|
||||||
@ -28,33 +33,33 @@ class GTLVQ(nn.Module):
|
|||||||
data. Only distances are available
|
data. Only distances are available
|
||||||
local_proj: computs tangent distances and returns the projected data
|
local_proj: computs tangent distances and returns the projected data
|
||||||
for further use. Be careful: data is repeated by number of prototypes
|
for further use. Be careful: data is repeated by number of prototypes
|
||||||
global: Number of subspaces is set to one and every prototypes
|
global: Number of subspaces is set to one and every prototypes
|
||||||
uses the same.
|
uses the same.
|
||||||
|
|
||||||
prototypes_per_class: int (default=2,optional)
|
prototypes_per_class: int (default=2,optional)
|
||||||
Number of prototypes per class
|
Number of prototypes per class
|
||||||
|
|
||||||
feature_dim: int (default=256)
|
feature_dim: int (default=256)
|
||||||
Dimensionality of the feature space specified as integer.
|
Dimensionality of the feature space specified as integer.
|
||||||
Prototype dimension.
|
Prototype dimension.
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
-----
|
-----
|
||||||
The GTLVQ [1] is a prototype-based classification learning model. The
|
The GTLVQ [1] is a prototype-based classification learning model. The
|
||||||
GTLVQ uses the Tangent-Distances for a local point approximation
|
GTLVQ uses the Tangent-Distances for a local point approximation
|
||||||
of an assumed data manifold via prototypial representations.
|
of an assumed data manifold via prototypial representations.
|
||||||
|
|
||||||
The GTLVQ requires subspace projectors for transforming the data
|
The GTLVQ requires subspace projectors for transforming the data
|
||||||
and prototypes into the affine subspace. Every prototype is
|
and prototypes into the affine subspace. Every prototype is
|
||||||
equipped with a specific subpspace and represents a point
|
equipped with a specific subpspace and represents a point
|
||||||
approximation of the assumed manifold.
|
approximation of the assumed manifold.
|
||||||
|
|
||||||
In practice prototypes and data are projected on this manifold
|
In practice prototypes and data are projected on this manifold
|
||||||
and pairwise euclidean distance computes.
|
and pairwise euclidean distance computes.
|
||||||
|
|
||||||
References
|
References
|
||||||
----------
|
----------
|
||||||
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
|
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
|
||||||
in classification based on manifolc. models and its relation
|
in classification based on manifolc. models and its relation
|
||||||
to tangent metric learning. In: 2017 International Joint
|
to tangent metric learning. In: 2017 International Joint
|
||||||
Conference on Neural Networks (IJCNN).
|
Conference on Neural Networks (IJCNN).
|
||||||
@ -82,13 +87,9 @@ class GTLVQ(nn.Module):
|
|||||||
self.tpt = tangent_projection_type
|
self.tpt = tangent_projection_type
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.tpt == 'local' or self.tpt == 'local_proj':
|
if self.tpt == 'local' or self.tpt == 'local_proj':
|
||||||
self.subspaces = torch.nn.Parameter(
|
self.init_local_subspace(subspace_data)
|
||||||
self.init_local_subspace(
|
|
||||||
subspace_data).clone().detach().requires_grad_(True))
|
|
||||||
elif self.tpt == 'global':
|
elif self.tpt == 'global':
|
||||||
self.subspaces = torch.nn.Parameter(
|
self.init_gobal_subspace(subspace_data, subspace_size)
|
||||||
self.init_gobal_subspace(
|
|
||||||
subspace_data).clone().detach().requires_grad_(True))
|
|
||||||
else:
|
else:
|
||||||
self.subspaces = None
|
self.subspaces = None
|
||||||
|
|
||||||
@ -125,13 +126,17 @@ class GTLVQ(nn.Module):
|
|||||||
def init_gobal_subspace(self, data, num_subspaces):
|
def init_gobal_subspace(self, data, num_subspaces):
|
||||||
_, _, v = torch.svd(data)
|
_, _, v = torch.svd(data)
|
||||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||||
return subspace[:, :num_subspaces]
|
subspaces = subspace[:, :num_subspaces]
|
||||||
|
self.subspaces = torch.nn.Parameter(
|
||||||
|
subspaces).clone().detach().requires_grad_(True)
|
||||||
|
|
||||||
def init_local_subspace(self, data):
|
def init_local_subspace(self, data):
|
||||||
_, _, v = torch.svd(data)
|
_, _, v = torch.svd(data)
|
||||||
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||||
return inital_projector.unsqueeze(0).repeat_interleave(
|
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
|
||||||
self.num_protos, 0)
|
self.num_protos, 0)
|
||||||
|
self.subspaces = torch.nn.Parameter(
|
||||||
|
subspaces).clone().detach().requires_grad_(True)
|
||||||
|
|
||||||
def global_tangent_distances(self, x):
|
def global_tangent_distances(self, x):
|
||||||
# Tangent Projection
|
# Tangent Projection
|
||||||
@ -154,13 +159,11 @@ class GTLVQ(nn.Module):
|
|||||||
# Origin Author:
|
# Origin Author:
|
||||||
|
|
||||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
||||||
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
|
_, proto_int_shape = _int_and_mixed_shape(protos)
|
||||||
|
|
||||||
# check if the shapes are correct
|
# check if the shapes are correct
|
||||||
_check_shapes(signal_int_shape, proto_int_shape)
|
_check_shapes(signal_int_shape, proto_int_shape)
|
||||||
|
|
||||||
atom_axes = list(range(3, len(signal_int_shape)))
|
|
||||||
|
|
||||||
# Tangent Data Projections
|
# Tangent Data Projections
|
||||||
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
|
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
|
||||||
data = signals.squeeze(2).permute([1, 0, 2])
|
data = signals.squeeze(2).permute([1, 0, 2])
|
||||||
@ -170,7 +173,7 @@ class GTLVQ(nn.Module):
|
|||||||
projected_diff = torch.reshape(
|
projected_diff = torch.reshape(
|
||||||
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
|
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||||
signal_shape[3:])
|
signal_shape[3:])
|
||||||
diss = torch.norm(projected_diff,2,dim=-1)
|
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||||
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
|
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
|
||||||
|
|
||||||
def get_parameters(self):
|
def get_parameters(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user