gtlvq
This commit is contained in:
190
prototorch/modules/models.py
Normal file
190
prototorch/modules/models.py
Normal file
@@ -0,0 +1,190 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape
|
||||
|
||||
class GTLVQ(nn.Module):
|
||||
r""" Generalized Tangent Learning Vector Quantization
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes: int
|
||||
Number of classes of the given classification problem.
|
||||
|
||||
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
|
||||
Subspace data for the point approximation, required
|
||||
|
||||
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
|
||||
prototype data for initalization of the prototypes used in GTLVQ.
|
||||
|
||||
tangent_projection_type: string
|
||||
Specifies the tangent projection type
|
||||
options: local
|
||||
local_proj
|
||||
global
|
||||
local: computes the tangent distances without emphasizing projected
|
||||
data. Only distances are available
|
||||
local_proj: computs tangent distances and returns the projected data
|
||||
for further use. Be careful: data is repeated by number of prototypes
|
||||
global: Number of subspaces is set to one and every prototypes
|
||||
uses the same.
|
||||
|
||||
prototypes_per_class: int (default=2,optional)
|
||||
Number of prototypes per class
|
||||
|
||||
feature_dim: int (default=256)
|
||||
Dimensionality of the feature space specified as integer.
|
||||
Prototype dimension.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The GTLVQ [1] is a prototype-based classification learning model. The
|
||||
GTLVQ uses the Tangent-Distances for a local point approximation
|
||||
of an assumed data manifold via prototypial representations.
|
||||
|
||||
The GTLVQ requires subspace projectors for transforming the data
|
||||
and prototypes into the affine subspace. Every prototype is
|
||||
equipped with a specific subpspace and represents a point
|
||||
approximation of the assumed manifold.
|
||||
|
||||
In practice prototypes and data are projected on this manifold
|
||||
and pairwise euclidean distance computes.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
|
||||
in classification based on manifolc. models and its relation
|
||||
to tangent metric learning. In: 2017 International Joint
|
||||
Conference on Neural Networks (IJCNN).
|
||||
Bd. 2017-May : IEEE, 2017, S. 1756–1765
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
subspace_data=None,
|
||||
prototype_data=None,
|
||||
subspace_size=256,
|
||||
tangent_projection_type='local',
|
||||
prototypes_per_class=2,
|
||||
feature_dim=256,
|
||||
):
|
||||
super(GTLVQ, self).__init__()
|
||||
|
||||
self.num_protos = num_classes * prototypes_per_class
|
||||
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
||||
self.feature_dim = feature_dim
|
||||
|
||||
if subspace_data is None:
|
||||
raise ValueError('Init Data must be specified!')
|
||||
|
||||
self.tpt = tangent_projection_type
|
||||
with torch.no_grad():
|
||||
if self.tpt == 'local' or self.tpt == 'local_proj':
|
||||
self.subspaces = torch.nn.Parameter(
|
||||
self.init_local_subspace(
|
||||
subspace_data).clone().detach().requires_grad_(True))
|
||||
elif self.tpt == 'global':
|
||||
self.subspaces = torch.nn.Parameter(
|
||||
self.init_gobal_subspace(
|
||||
subspace_data).clone().detach().requires_grad_(True))
|
||||
else:
|
||||
self.subspaces = None
|
||||
|
||||
# Hypothesis-Margin-Classifier
|
||||
self.cls = Prototypes1D(input_dim=feature_dim,
|
||||
prototypes_per_class=prototypes_per_class,
|
||||
nclasses=num_classes,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=prototype_data)
|
||||
|
||||
def forward(self, x):
|
||||
# Tangent Projection
|
||||
if self.tpt == 'local_proj':
|
||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||
1).unsqueeze(2)
|
||||
dis, proj_x = self.local_tangent_projection(
|
||||
x_conform, self.cls.prototypes, self.subspaces)
|
||||
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
|
||||
self.feature_dim)
|
||||
return proj_x, dis
|
||||
elif self.tpt == "local":
|
||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||
1).unsqueeze(2)
|
||||
dis = tangent_distance(x_conform, self.cls.prototypes,
|
||||
self.subspaces)
|
||||
elif self.tpt == "gloabl":
|
||||
dis = self.global_tangent_distances(x)
|
||||
else:
|
||||
dis = (x @ self.cls.prototypes.T) / (
|
||||
torch.norm(x, dim=1, keepdim=True) @ torch.norm(
|
||||
self.cls.prototypes, dim=1, keepdim=True).T)
|
||||
return dis
|
||||
|
||||
def init_gobal_subspace(self, data, num_subspaces):
|
||||
_, _, v = torch.svd(data)
|
||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||
return subspace[:, :num_subspaces]
|
||||
|
||||
def init_local_subspace(self, data):
|
||||
_, _, v = torch.svd(data)
|
||||
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||
return inital_projector.unsqueeze(0).repeat_interleave(
|
||||
self.num_protos, 0)
|
||||
|
||||
def global_tangent_distances(self, x):
|
||||
# Tangent Projection
|
||||
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
|
||||
# Euclidean Distance
|
||||
return euclidean_distance_matrix(x, projected_prototypes)
|
||||
|
||||
def local_tangent_projection(self,
|
||||
signals,
|
||||
protos,
|
||||
subspaces,
|
||||
squared=False,
|
||||
epsilon=1e-10):
|
||||
# Note: subspaces is always assumed as transposed and must be orthogonal!
|
||||
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
||||
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
||||
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
||||
# subspace should be orthogonalized
|
||||
# Origin Source Code
|
||||
# Origin Author:
|
||||
|
||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
||||
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
|
||||
|
||||
# check if the shapes are correct
|
||||
_check_shapes(signal_int_shape, proto_int_shape)
|
||||
|
||||
atom_axes = list(range(3, len(signal_int_shape)))
|
||||
|
||||
# Tangent Data Projections
|
||||
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
|
||||
data = signals.squeeze(2).permute([1, 0, 2])
|
||||
projected_data = torch.bmm(data, subspaces)
|
||||
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
|
||||
diff = projected_data - projected_protos
|
||||
projected_diff = torch.reshape(
|
||||
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||
signal_shape[3:])
|
||||
diss = torch.norm(projected_diff,2,dim=-1)
|
||||
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
|
||||
|
||||
def get_parameters(self):
|
||||
return {
|
||||
"params": self.cls.prototypes,
|
||||
}, {
|
||||
"params": self.subspaces
|
||||
}
|
||||
|
||||
def orthogonalize_subspace(self):
|
||||
if self.subspaces is not None:
|
||||
with torch.no_grad():
|
||||
ortho_subpsaces = orthogonalization(
|
||||
self.subspaces
|
||||
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
|
||||
self.subspaces)
|
||||
self.subspaces.copy_(ortho_subpsaces)
|
Reference in New Issue
Block a user