import torch
from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix,
                                            tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D


class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
Parameters
----------
num_classes: int
Number of classes of the given classification problem.
    subspace_data: torch.tensor of shape (n_batch, feature_dim, feature_dim)
        Subspace data for the point approximation; required.
    prototype_data: torch.tensor of shape (n_init_data, feature_dim) (optional)
        Prototype data for initialization of the prototypes used in GTLVQ.
subspace_size: int (default=256,optional)
        Subspace dimension of the projectors. Currently only supported
        with tangent_projection_type=global.
tangent_projection_type: string
Specifies the tangent projection type
options: local
local_proj
global
        local: computes the tangent distances without emphasizing projected
            data. Only distances are available.
        local_proj: computes tangent distances and returns the projected data
            for further use. Be careful: data is repeated by the number of
            prototypes.
        global: the number of subspaces is set to one and every prototype
            uses the same one.
prototypes_per_class: int (default=2,optional)
Number of prototypes per class
feature_dim: int (default=256)
Dimensionality of the feature space specified as integer.
Prototype dimension.
Notes
-----
    GTLVQ [1] is a prototype-based classification learning model. It
    uses tangent distances for a local point approximation of an assumed
    data manifold via prototypical representations.
    GTLVQ requires subspace projectors to transform the data and the
    prototypes into an affine subspace. Every prototype is equipped with
    a specific subspace and represents a point approximation of the
    assumed manifold.
    In practice, prototypes and data are projected onto this manifold and
    the pairwise Euclidean distances are computed.
References
----------
    .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning in
        classification based on manifold models and its relation to tangent
        metric learning. In: 2017 International Joint Conference on Neural
        Networks (IJCNN), IEEE, 2017, pp. 1756-1765.
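
    Examples
    --------
    A minimal usage sketch; the tensor shapes below are illustrative
    assumptions, and whether ``prototype_data`` may be omitted depends on
    the ``Prototypes1D`` initializer:

    >>> import torch
    >>> init_data = torch.randn(100, 256)
    >>> model = GTLVQ(num_classes=10,
    ...               subspace_data=init_data,
    ...               prototype_data=init_data,
    ...               tangent_projection_type="local",
    ...               feature_dim=256)
    >>> x = torch.randn(8, 256)
    >>> distances = model(x)  # one tangent distance per prototype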
"""
def __init__(
self,
num_classes,
subspace_data=None,
prototype_data=None,
subspace_size=256,
tangent_projection_type="local",
prototypes_per_class=2,
feature_dim=256,
):
super(GTLVQ, self).__init__()
self.num_protos = num_classes * prototypes_per_class
self.subspace_size = feature_dim if subspace_size is None else subspace_size
self.feature_dim = feature_dim
if subspace_data is None:
            raise ValueError("subspace_data must be specified!")
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == "local" or self.tpt == "local_proj":
self.init_local_subspace(subspace_data)
elif self.tpt == "global":
                self.init_global_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
# Hypothesis-Margin-Classifier
self.cls = Prototypes1D(
input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer="stratified_mean",
data=prototype_data,
)

    def forward(self, x):
# Tangent Projection
if self.tpt == "local_proj":
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis, proj_x = self.local_tangent_projection(x_conform)
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
self.feature_dim)
return proj_x, dis
elif self.tpt == "local":
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis = tangent_distance(x_conform, self.cls.prototypes,
self.subspaces)
        elif self.tpt == "global":
            dis = self.global_tangent_distances(x)
else:
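            # No subspaces configured: fall back to the cosine similarity
            # between the data and the prototypes.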
dis = (x @ self.cls.prototypes.T) / (
torch.norm(x, dim=1, keepdim=True) @ torch.norm(
self.cls.prototypes, dim=1, keepdim=True).T)
return dis

    def init_global_subspace(self, data, num_subspaces):
        # The projector spans the orthogonal complement of the data's
        # right-singular vectors; one global subspace is shared by all
        # prototypes.
        _, _, v = torch.svd(data)
        subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = subspace[:, :num_subspaces]
        self.subspaces = torch.nn.Parameter(subspaces.clone().detach(),
                                            requires_grad=True)

    def init_local_subspace(self, data):
        # Same projector as in the global case, but repeated so that every
        # prototype gets its own trainable copy.
        _, _, v = torch.svd(data)
        initial_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = initial_projector.unsqueeze(0).repeat_interleave(
            self.num_protos, 0)
        self.subspaces = torch.nn.Parameter(subspaces.clone().detach(),
                                            requires_grad=True)

    def global_tangent_distances(self, x):
# Tangent Projection
x, projected_prototypes = (
x @ self.subspaces,
self.cls.prototypes @ self.subspaces,
)
# Euclidean Distance
return euclidean_distance_matrix(x, projected_prototypes)

    def local_tangent_projection(self, signals):
# Note: subspaces is always assumed as transposed and must be orthogonal!
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
# subspace should be orthogonalized
protos = self.cls.prototypes
subspaces = self.subspaces
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
_, proto_int_shape = _int_and_mixed_shape(protos)
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
# Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2])
projected_data = torch.bmm(data, subspaces)
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
diff = projected_data - projected_protos
projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
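        # The tangent distance is the Euclidean norm of the projected
        # difference vectors.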
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)

    def get_parameters(self):
        # Expose the prototypes and the subspaces as separate parameter
        # groups, e.g. for per-group learning rates.
        return {"params": self.cls.prototypes}, {"params": self.subspaces}

    def orthogonalize_subspace(self):
if self.subspaces is not None:
with torch.no_grad():
                ortho_subspaces = (orthogonalization(self.subspaces)
                                   if self.tpt == "global" else
                                   torch.nn.init.orthogonal_(self.subspaces))
                self.subspaces.copy_(ortho_subspaces)
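

# Minimal training sketch (illustrative only): the optimizer, loss, and data
# loader below are assumptions, not part of this module. It shows the intended
# call order: forward pass, parameter update, then re-orthogonalization of the
# subspaces.
#
#     model = GTLVQ(num_classes=10, subspace_data=train_x, feature_dim=256)
#     optimizer = torch.optim.Adam(model.get_parameters(), lr=1e-3)
#     for x, y in loader:
#         distances = model(x)
#         loss = criterion(distances, y)  # e.g. a GLVQ-type loss
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         model.orthogonalize_subspace()  # keep the subspaces orthogonal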