Automatic Formatting.

This commit is contained in:
Alexander Engelsberger
2021-04-23 17:24:53 +02:00
parent e1d56595c1
commit 7c30ffe2c7
28 changed files with 393 additions and 321 deletions

View File

@@ -3,5 +3,5 @@
from .prototypes import Prototypes1D
__all__ = [
'Prototypes1D',
"Prototypes1D",
]

View File

@@ -7,7 +7,7 @@ from prototorch.functions.losses import glvq_loss
class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
@@ -37,4 +37,4 @@ class NeuralGasEnergy(torch.nn.Module):
@staticmethod
def _nghood_fn(rankings, lm):
return torch.exp(-rankings / lm)
return torch.exp(-rankings / lm)

View File

@@ -1,9 +1,11 @@
from torch import nn
import torch
from prototorch.modules.prototypes import Prototypes1D
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
class GTLVQ(nn.Module):
@@ -71,7 +73,7 @@ class GTLVQ(nn.Module):
subspace_data=None,
prototype_data=None,
subspace_size=256,
tangent_projection_type='local',
tangent_projection_type="local",
prototypes_per_class=2,
feature_dim=256,
):
@@ -82,37 +84,39 @@ class GTLVQ(nn.Module):
self.feature_dim = feature_dim
if subspace_data is None:
raise ValueError('Init Data must be specified!')
raise ValueError("Init Data must be specified!")
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj':
if self.tpt == "local" or self.tpt == "local_proj":
self.init_local_subspace(subspace_data)
elif self.tpt == 'global':
elif self.tpt == "global":
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
# Hypothesis-Margin-Classifier
self.cls = Prototypes1D(input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer='stratified_mean',
data=prototype_data)
self.cls = Prototypes1D(
input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer="stratified_mean",
data=prototype_data,
)
def forward(self, x):
# Tangent Projection
if self.tpt == 'local_proj':
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2)
if self.tpt == "local_proj":
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis, proj_x = self.local_tangent_projection(x_conform)
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
self.feature_dim)
return proj_x, dis
elif self.tpt == "local":
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2)
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis = tangent_distance(x_conform, self.cls.prototypes,
self.subspaces)
elif self.tpt == "gloabl":
@@ -127,25 +131,27 @@ class GTLVQ(nn.Module):
_, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True))
def init_local_subspace(self, data):
_, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0)
self.subspaces = torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True)
self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True))
def global_tangent_distances(self, x):
# Tangent Projection
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
x, projected_prototypes = (
x @ self.subspaces,
self.cls.prototypes @ self.subspaces,
)
# Euclidean Distance
return euclidean_distance_matrix(x, projected_prototypes)
def local_tangent_projection(self,
signals):
def local_tangent_projection(self, signals):
# Note: subspaces is always assumed as transposed and must be orthogonal!
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
@@ -183,8 +189,7 @@ class GTLVQ(nn.Module):
def orthogonalize_subspace(self):
if self.subspaces is not None:
with torch.no_grad():
ortho_subpsaces = orthogonalization(
self.subspaces
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
self.subspaces)
ortho_subpsaces = (orthogonalization(self.subspaces)
if self.tpt == "global" else
torch.nn.init.orthogonal_(self.subspaces))
self.subspaces.copy_(ortho_subpsaces)

View File

@@ -29,14 +29,16 @@ class Prototypes1D(_Prototypes):
TODO Complete this doc-string.
"""
def __init__(self,
prototypes_per_class=1,
prototype_initializer="ones",
prototype_distribution=None,
data=None,
dtype=torch.float32,
one_hot_labels=False,
**kwargs):
def __init__(
self,
prototypes_per_class=1,
prototype_initializer="ones",
prototype_distribution=None,
data=None,
dtype=torch.float32,
one_hot_labels=False,
**kwargs,
):
# Convert tensors to python lists before processing
if prototype_distribution is not None: