import torch
from torch import nn

from prototorch.functions.distances import (euclidean_distance_matrix,
                                            tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D


class GTLVQ(nn.Module):
    r"""Generalized Tangent Learning Vector Quantization

    Parameters
    ----------
    num_classes: int
        Number of classes of the given classification problem.

    subspace_data: torch.Tensor of shape (n_batch, feature_dim, feature_dim)
        Subspace data for the point approximation, required.

    prototype_data: torch.Tensor of shape (n_init_data, feature_dim) (optional)
        Prototype data for initialization of the prototypes used in GTLVQ.

    subspace_size: int (default=256, optional)
        Subspace dimension of the projectors. Currently only supported
        with tangent_projection_type="global".

    tangent_projection_type: string
        Specifies the tangent projection type.
        Options: local, local_proj, global.

        local: computes the tangent distances without returning the
        projected data; only distances are available.

        local_proj: computes the tangent distances and returns the projected
        data for further use. Be careful: the data is repeated by the number
        of prototypes.

        global: the number of subspaces is set to one and every prototype
        uses the same subspace.

    prototypes_per_class: int (default=2, optional)
        Number of prototypes per class.

    feature_dim: int (default=256)
        Dimensionality of the feature space specified as an integer.
        Prototype dimension.

    Notes
    -----
    GTLVQ [1]_ is a prototype-based classification learning model. It uses
    tangent distances for a local point approximation of an assumed data
    manifold via prototypical representations.

    GTLVQ requires subspace projectors for transforming the data and
    prototypes into the affine subspace. Every prototype is equipped with a
    specific subspace and represents a point approximation of the assumed
    manifold.

    In practice, prototypes and data are projected onto this manifold and
    the pairwise Euclidean distances are computed.

    References
    ----------
    .. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning in
       classification based on manifold models and its relation to tangent
       metric learning. In: 2017 International Joint Conference on Neural
       Networks (IJCNN). Bd. 2017-May: IEEE, 2017, S. 1756–1765.
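
    Example
    -------
    A minimal usage sketch, not part of the original source; the tensor
    shapes, the synthetic labels, and the ``(inputs, labels)`` format
    assumed for ``prototype_data`` are illustrative assumptions only:

    >>> x = torch.randn(32, 64)
    >>> y = torch.randint(0, 4, (32,))
    >>> model = GTLVQ(num_classes=4, subspace_data=x,
    ...               prototype_data=[x, y],
    ...               tangent_projection_type="local", feature_dim=64)
    >>> distances = model(x)  # tangent distances, shape (32, num_protos)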
"""
|
||
def __init__(
|
||
self,
|
||
num_classes,
|
||
subspace_data=None,
|
||
prototype_data=None,
|
||
subspace_size=256,
|
||
tangent_projection_type="local",
|
||
prototypes_per_class=2,
|
||
feature_dim=256,
|
||
):
|
||
super(GTLVQ, self).__init__()
|
||
|
||
self.num_protos = num_classes * prototypes_per_class
|
||
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
||
self.feature_dim = feature_dim
|
||
|
||
if subspace_data is None:
|
||
raise ValueError("Init Data must be specified!")
|
||
|
||
self.tpt = tangent_projection_type
|
||
with torch.no_grad():
|
||
if self.tpt == "local" or self.tpt == "local_proj":
|
||
self.init_local_subspace(subspace_data)
|
||
elif self.tpt == "global":
|
||
self.init_gobal_subspace(subspace_data, subspace_size)
|
||
else:
|
||
self.subspaces = None
|
||
|
||
# Hypothesis-Margin-Classifier
|
||
self.cls = Prototypes1D(
|
||
input_dim=feature_dim,
|
||
prototypes_per_class=prototypes_per_class,
|
||
nclasses=num_classes,
|
||
prototype_initializer="stratified_mean",
|
||
data=prototype_data,
|
||
)

    def forward(self, x):
        # Tangent Projection
        if self.tpt == "local_proj":
            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
                                                          1).unsqueeze(2))
            dis, proj_x = self.local_tangent_projection(x_conform)

            proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
                                    self.feature_dim)
            return proj_x, dis
        elif self.tpt == "local":
            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
                                                          1).unsqueeze(2))
            dis = tangent_distance(x_conform, self.cls.prototypes,
                                   self.subspaces)
        elif self.tpt == "global":  # fixed typo: was "gloabl", which never matched
            dis = self.global_tangent_distances(x)
        else:
            # Fallback: cosine similarity between inputs and prototypes
            dis = (x @ self.cls.prototypes.T) / (
                torch.norm(x, dim=1, keepdim=True) @ torch.norm(
                    self.cls.prototypes, dim=1, keepdim=True).T)
        return dis

    def init_global_subspace(self, data, num_subspaces):
        _, _, v = torch.svd(data)
        subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = subspace[:, :num_subspaces]
        # Register as a trainable parameter; the original chained
        # Parameter(...).clone().detach() produced a plain tensor that was
        # never registered with the module.
        self.subspaces = torch.nn.Parameter(subspaces, requires_grad=True)

    def init_local_subspace(self, data):
        _, _, v = torch.svd(data)
        initial_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = initial_projector.unsqueeze(0).repeat_interleave(
            self.num_protos, 0)
        # Register as a trainable parameter; see init_global_subspace.
        self.subspaces = torch.nn.Parameter(subspaces, requires_grad=True)

    def global_tangent_distances(self, x):
        # Tangent Projection
        x, projected_prototypes = (
            x @ self.subspaces,
            self.cls.prototypes @ self.subspaces,
        )
        # Euclidean Distance
        return euclidean_distance_matrix(x, projected_prototypes)

    def local_tangent_projection(self, signals):
        # Note: subspaces is always assumed as transposed and must be orthogonal!
        # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
        # shape(protos): proto_number x dim1 x dim2 x ... x dimN
        # shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
        # subspace should be orthogonalized
        # Origin Source Code
        # Origin Author:
        protos = self.cls.prototypes
        subspaces = self.subspaces
        signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
        _, proto_int_shape = _int_and_mixed_shape(protos)

        # check if the shapes are correct
        _check_shapes(signal_int_shape, proto_int_shape)

        # Tangent Data Projections
        projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
        data = signals.squeeze(2).permute([1, 0, 2])
        projected_data = torch.bmm(data, subspaces)
        projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
        diff = projected_data - projected_protos
        projected_diff = torch.reshape(
            diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
            signal_shape[3:])
        diss = torch.norm(projected_diff, 2, dim=-1)
        return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)

    def get_parameters(self):
        return {
            "params": self.cls.prototypes,
        }, {
            "params": self.subspaces
        }

    def orthogonalize_subspace(self):
        if self.subspaces is not None:
            with torch.no_grad():
                ortho_subspaces = (orthogonalization(self.subspaces)
                                   if self.tpt == "global" else
                                   torch.nn.init.orthogonal_(self.subspaces))
                self.subspaces.copy_(ortho_subspaces)
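

# ----------------------------------------------------------------------
# Minimal training-step sketch, not part of the original module. The
# synthetic tensors, the (inputs, labels) format assumed for
# ``prototype_data``, and the placeholder loss below are illustrative
# assumptions only, not the reference GLVQ training loop.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(32, 64)  # synthetic feature batch
    y = torch.randint(0, 4, (32,))  # synthetic integer labels
    model = GTLVQ(
        num_classes=4,
        subspace_data=x,
        prototype_data=[x, y],  # assumed (inputs, labels) format
        tangent_projection_type="local",
        prototypes_per_class=2,
        feature_dim=64,
    )
    # Prototypes and subspaces are exposed as separate parameter groups.
    proto_group, subspace_group = model.get_parameters()
    optimizer = torch.optim.Adam([proto_group, subspace_group], lr=1e-3)

    distances = model(x)  # tangent distances, shape (32, num_protos)
    loss = distances.min(dim=1).values.mean()  # placeholder loss, not GLVQ's
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Re-orthogonalize the subspaces after the gradient step.
    model.orthogonalize_subspace()
    print(loss.item())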