Automatic Formatting.
@@ -3,5 +3,5 @@
 from .prototypes import Prototypes1D
 
 __all__ = [
-    'Prototypes1D',
+    "Prototypes1D",
 ]
@@ -7,7 +7,7 @@ from prototorch.functions.losses import glvq_loss
 
 
 class GLVQLoss(torch.nn.Module):
-    def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
+    def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
         super().__init__(**kwargs)
         self.margin = margin
         self.squashing = get_activation(squashing)
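For orientation, GLVQLoss wraps the glvq_loss function imported in the hunk header. As a point of reference only, the classic GLVQ cost it is named after can be sketched in plain PyTorch as follows; this is an illustration, not the library's implementation, and the actual signature and reduction may differ:

import torch

def glvq_loss_sketch(distances, target_labels, prototype_labels):
    # distances: (batch, num_prototypes), smaller means closer
    # mu(x) = (d+ - d-) / (d+ + d-), in [-1, 1]; negative = correctly classified
    matching = target_labels.unsqueeze(1) == prototype_labels.unsqueeze(0)
    inf = torch.full_like(distances, float("inf"))
    d_plus = torch.where(matching, distances, inf).min(dim=1).values
    d_minus = torch.where(~matching, distances, inf).min(dim=1).values
    return (d_plus - d_minus) / (d_plus + d_minus)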
@@ -37,4 +37,4 @@ class NeuralGasEnergy(torch.nn.Module):
 
     @staticmethod
     def _nghood_fn(rankings, lm):
-        return torch.exp(-rankings / lm)
+        return torch.exp(-rankings / lm)
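The _nghood_fn above is the standard Neural Gas neighborhood decay, exp(-rank / lambda). A minimal illustration of how such rankings are typically obtained and used; the variable names here are made up, not taken from the library:

import torch

d = torch.rand(8, 5)                                # distances: batch x prototypes
rankings = d.argsort(dim=1).argsort(dim=1).float()  # rank 0 = closest prototype
lm = 2.0                                            # neighborhood range lambda
weights = torch.exp(-rankings / lm)                 # same formula as _nghood_fn
energy = (weights * d).sum()                        # Neural Gas energy term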
@@ -1,9 +1,11 @@
-from torch import nn
 import torch
-from prototorch.modules.prototypes import Prototypes1D
-from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
-from prototorch.functions.normalization import orthogonalization
+from torch import nn
+
+from prototorch.functions.distances import (euclidean_distance_matrix,
+                                            tangent_distance)
+from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
+from prototorch.functions.normalization import orthogonalization
+from prototorch.modules.prototypes import Prototypes1D
 
 
 class GTLVQ(nn.Module):
@@ -71,7 +73,7 @@ class GTLVQ(nn.Module):
         subspace_data=None,
         prototype_data=None,
         subspace_size=256,
-        tangent_projection_type='local',
+        tangent_projection_type="local",
         prototypes_per_class=2,
         feature_dim=256,
     ):
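A hedged construction sketch using only what is visible in this diff: num_classes appears as a local name in the next hunk and is presumably an earlier constructor argument, and the expected format of prototype_data is an assumption.

import torch

features = torch.randn(100, 256)        # matches feature_dim=256
labels = torch.randint(0, 10, (100,))
model = GTLVQ(                          # the class defined above
    num_classes=10,                     # assumed constructor argument
    subspace_data=features,
    prototype_data=(features, labels),  # (inputs, labels); format is an assumption
    tangent_projection_type="local",
    prototypes_per_class=2,
    feature_dim=256,
)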
@@ -82,37 +84,39 @@ class GTLVQ(nn.Module):
         self.feature_dim = feature_dim
 
         if subspace_data is None:
-            raise ValueError('Init Data must be specified!')
+            raise ValueError("Init Data must be specified!")
 
         self.tpt = tangent_projection_type
         with torch.no_grad():
-            if self.tpt == 'local' or self.tpt == 'local_proj':
+            if self.tpt == "local" or self.tpt == "local_proj":
                 self.init_local_subspace(subspace_data)
-            elif self.tpt == 'global':
+            elif self.tpt == "global":
                 self.init_gobal_subspace(subspace_data, subspace_size)
             else:
                 self.subspaces = None
 
         # Hypothesis-Margin-Classifier
-        self.cls = Prototypes1D(input_dim=feature_dim,
-                                prototypes_per_class=prototypes_per_class,
-                                nclasses=num_classes,
-                                prototype_initializer='stratified_mean',
-                                data=prototype_data)
+        self.cls = Prototypes1D(
+            input_dim=feature_dim,
+            prototypes_per_class=prototypes_per_class,
+            nclasses=num_classes,
+            prototype_initializer="stratified_mean",
+            data=prototype_data,
+        )
 
     def forward(self, x):
         # Tangent Projection
-        if self.tpt == 'local_proj':
-            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
-                                                         1).unsqueeze(2)
+        if self.tpt == "local_proj":
+            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
+                                                          1).unsqueeze(2))
             dis, proj_x = self.local_tangent_projection(x_conform)
 
             proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
                                     self.feature_dim)
             return proj_x, dis
         elif self.tpt == "local":
-            x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
-                                                         1).unsqueeze(2)
+            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
+                                                          1).unsqueeze(2))
             dis = tangent_distance(x_conform, self.cls.prototypes,
                                    self.subspaces)
         elif self.tpt == "gloabl":
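Context for the tangent_distance calls above: a tangent distance compares an input to a prototype after removing the component that lies in a learned subspace, so variation along the subspace is not penalized. A standalone sketch of the idea, independent of the library's actual tangent_distance signature:

import torch

def tangent_distance_sketch(x, w, W):
    # x, w: (dim,) input and prototype; W: (dim, k) with orthonormal columns
    diff = x - w
    residual = diff - W @ (W.T @ diff)  # project off the tangent subspace
    return residual.pow(2).sum()

x, w = torch.randn(256), torch.randn(256)
W, _ = torch.linalg.qr(torch.randn(256, 12))  # orthonormal basis
print(tangent_distance_sketch(x, w, W))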
@@ -127,25 +131,27 @@ class GTLVQ(nn.Module):
         _, _, v = torch.svd(data)
         subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
         subspaces = subspace[:, :num_subspaces]
-        self.subspaces = torch.nn.Parameter(
-            subspaces).clone().detach().requires_grad_(True)
+        self.subspaces = (torch.nn.Parameter(
+            subspaces).clone().detach().requires_grad_(True))
 
     def init_local_subspace(self, data):
         _, _, v = torch.svd(data)
         inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
         subspaces = inital_projector.unsqueeze(0).repeat_interleave(
             self.num_protos, 0)
-        self.subspaces = torch.nn.Parameter(
-            subspaces).clone().detach().requires_grad_(True)
+        self.subspaces = (torch.nn.Parameter(
+            subspaces).clone().detach().requires_grad_(True))
 
     def global_tangent_distances(self, x):
         # Tangent Projection
-        x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
+        x, projected_prototypes = (
+            x @ self.subspaces,
+            self.cls.prototypes @ self.subspaces,
+        )
         # Euclidean Distance
         return euclidean_distance_matrix(x, projected_prototypes)
 
-    def local_tangent_projection(self,
-                                 signals):
+    def local_tangent_projection(self, signals):
         # Note: subspaces is always assumed as transposed and must be orthogonal!
         # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
         # shape(protos): proto_number x dim1 x dim2 x ... x dimN
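For orientation, global_tangent_distances above projects inputs and prototypes with the same subspace matrix and then compares them with a Euclidean distance matrix. Such a matrix can be sketched with torch.cdist; whether the library's euclidean_distance_matrix returns squared distances is not visible in this diff:

import torch

x = torch.randn(8, 256)        # projected inputs
protos = torch.randn(20, 256)  # projected prototypes
dis = torch.cdist(x, protos)   # (8, 20) pairwise Euclidean distances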
@@ -183,8 +189,7 @@ class GTLVQ(nn.Module):
     def orthogonalize_subspace(self):
         if self.subspaces is not None:
             with torch.no_grad():
-                ortho_subpsaces = orthogonalization(
-                    self.subspaces
-                ) if self.tpt == 'global' else torch.nn.init.orthogonal_(
-                    self.subspaces)
+                ortho_subpsaces = (orthogonalization(self.subspaces)
+                                   if self.tpt == "global" else
+                                   torch.nn.init.orthogonal_(self.subspaces))
                 self.subspaces.copy_(ortho_subpsaces)
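The pattern above, rewriting a parameter in-place under torch.no_grad(), can be illustrated with a QR decomposition as a stand-in for the orthogonalization function; the stand-in is an assumption, not the library's implementation:

import torch

subspaces = torch.nn.Parameter(torch.randn(256, 12))
with torch.no_grad():
    q, _ = torch.linalg.qr(subspaces)  # orthonormal columns, same shape
    subspaces.copy_(q)
print(subspaces.T @ subspaces)  # approximately the 12 x 12 identity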
@@ -29,14 +29,16 @@ class Prototypes1D(_Prototypes):
 
     TODO Complete this doc-string.
     """
-    def __init__(self,
-                 prototypes_per_class=1,
-                 prototype_initializer="ones",
-                 prototype_distribution=None,
-                 data=None,
-                 dtype=torch.float32,
-                 one_hot_labels=False,
-                 **kwargs):
+    def __init__(
+        self,
+        prototypes_per_class=1,
+        prototype_initializer="ones",
+        prototype_distribution=None,
+        data=None,
+        dtype=torch.float32,
+        one_hot_labels=False,
+        **kwargs,
+    ):
 
         # Convert tensors to python lists before processing
         if prototype_distribution is not None:
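A hedged usage sketch of the reformatted constructor, combining the keyword arguments visible in this hunk with input_dim and nclasses from the GTLVQ hunk above; the expected data format is an assumption:

import torch
from prototorch.modules.prototypes import Prototypes1D

x = torch.randn(100, 4)
y = torch.randint(0, 3, (100,))
protos = Prototypes1D(
    input_dim=4,
    nclasses=3,
    prototypes_per_class=2,
    prototype_initializer="stratified_mean",
    data=(x, y),  # (inputs, labels); exact format is an assumption
)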