gtlvq
This commit is contained in:
parent
a55320a65b
commit
895281aabd
165
examples/gtlvq_mnist.py
Normal file
165
examples/gtlvq_mnist.py
Normal file
@ -0,0 +1,165 @@
|
||||
"""
|
||||
ProtoTorch GTLVQ example using MNIST data.
|
||||
The GTLVQ is placed as an classification model on
|
||||
top of a CNN, considered as featurer extractor.
|
||||
Initialization of subpsace and prototypes in
|
||||
Siamnese fashion
|
||||
For more info about GTLVQ see:
|
||||
DOI:10.1109/IJCNN.2016.7727534
|
||||
"""
|
||||
import sys
|
||||
|
||||
from torch.nn import parameter
|
||||
from matplotlib.pyplot import fill
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
||||
from prototorch.modules.models import GTLVQ
|
||||
|
||||
# Parameters and options
|
||||
n_epochs = 50
|
||||
batch_size_train = 64
|
||||
batch_size_test = 1000
|
||||
learning_rate = 0.1
|
||||
momentum = 0.5
|
||||
log_interval = 10
|
||||
cuda = "cuda:1"
|
||||
random_seed = 1
|
||||
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Configures reproducability
|
||||
torch.manual_seed(random_seed)
|
||||
np.random.seed(random_seed)
|
||||
|
||||
# Prepare and preprocess the data
|
||||
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
||||
'./files/',
|
||||
train=True,
|
||||
download=True,
|
||||
transform=torchvision.transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
||||
batch_size=batch_size_train,
|
||||
shuffle=True)
|
||||
|
||||
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
||||
'./files/',
|
||||
train=False,
|
||||
download=True,
|
||||
transform=torchvision.transforms.Compose(
|
||||
[transforms.ToTensor(),
|
||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
||||
batch_size=batch_size_test,
|
||||
shuffle=True)
|
||||
|
||||
|
||||
# Define the GLVQ model plus appropriate feature extractor
|
||||
class CNNGTLVQ(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type="local",
|
||||
prototypes_per_class=2,
|
||||
bottleneck_dim=128,
|
||||
):
|
||||
super(CNNGTLVQ, self).__init__()
|
||||
|
||||
#Feature Extractor - Simple CNN
|
||||
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
|
||||
nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
|
||||
nn.MaxPool2d(2), nn.Dropout(0.25),
|
||||
nn.Flatten(), nn.Linear(9216, bottleneck_dim),
|
||||
nn.Dropout(0.5), nn.LeakyReLU(),
|
||||
nn.LayerNorm(bottleneck_dim))
|
||||
|
||||
# Forward pass of subspace and prototype initialization data through feature extractor
|
||||
subspace_data = self.fe(subspace_data)
|
||||
prototype_data[0] = self.fe(prototype_data[0])
|
||||
|
||||
# Initialization of GTLVQ
|
||||
self.gtlvq = GTLVQ(num_classes,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type=tangent_projection_type,
|
||||
feature_dim=bottleneck_dim,
|
||||
prototypes_per_class=prototypes_per_class)
|
||||
|
||||
def forward(self, x):
|
||||
# Feature Extraction
|
||||
x = self.fe(x)
|
||||
|
||||
# GTLVQ Forward pass
|
||||
dis = self.gtlvq(x)
|
||||
return dis
|
||||
|
||||
|
||||
# Get init data
|
||||
subspace_data = torch.cat(
|
||||
[next(iter(train_loader))[0],
|
||||
next(iter(test_loader))[0]])
|
||||
prototype_data = next(iter(train_loader))
|
||||
|
||||
# Build the CNN GTLVQ model
|
||||
model = CNNGTLVQ(10,
|
||||
subspace_data,
|
||||
prototype_data,
|
||||
tangent_projection_type="local",
|
||||
bottleneck_dim=128).to(device)
|
||||
|
||||
# Optimize using SGD optimizer from `torch.optim`
|
||||
optimizer = torch.optim.Adam([{
|
||||
'params': model.fe.parameters()
|
||||
}, {
|
||||
'params': model.gtlvq.parameters()
|
||||
}],
|
||||
lr=learning_rate)
|
||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(n_epochs):
|
||||
for batch_idx, (x_train, y_train) in enumerate(train_loader):
|
||||
model.train()
|
||||
x_train, y_train = x_train.to(device), y_train.to(device)
|
||||
optimizer.zero_grad()
|
||||
|
||||
distances = model(x_train)
|
||||
plabels = model.gtlvq.cls.prototype_labels.to(device)
|
||||
|
||||
# Compute loss.
|
||||
loss = criterion([distances, plabels], y_train)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# GTLVQ uses projected SGD, which means to orthogonalize the subspaces after every gradient update.
|
||||
model.gtlvq.orthogonalize_subspace()
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
||||
print(
|
||||
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
||||
Train Acc: {acc.item():02.02f}')
|
||||
|
||||
# Test
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
for x_test, y_test in test_loader:
|
||||
x_test, y_test = x_test.to(device), y_test.to(device)
|
||||
test_distances = model(torch.tensor(x_test))
|
||||
test_plabels = model.gtlvq.cls.prototype_labels.to(device)
|
||||
i = torch.argmin(test_distances, 1)
|
||||
correct += torch.sum(y_test == test_plabels[i])
|
||||
total += y_test.size(0)
|
||||
print('Accuracy of the network on the test images: %d %%' %
|
||||
(torch.true_divide(correct, total) * 100))
|
||||
|
||||
# Save the model
|
||||
PATH = './glvq_mnist_model.pth'
|
||||
torch.save(model.state_dict(), PATH)
|
@ -1,6 +1,8 @@
|
||||
"""ProtoTorch distance functions."""
|
||||
|
||||
import torch
|
||||
from prototorch.functions.helper import equal_int_shape, _int_and_mixed_shape, _check_shapes
|
||||
import numpy as np
|
||||
|
||||
|
||||
def squared_euclidean_distance(x, y):
|
||||
@ -71,5 +73,155 @@ def lomega_distance(x, y, omegas):
|
||||
return distances
|
||||
|
||||
|
||||
def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
|
||||
r""" Computes an euclidean distanes matrix given two distinct vectors.
|
||||
last dimension must be the vector dimension!
|
||||
compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
|
||||
|
||||
x.shape = (number_of_x_vectors, vector_dim)
|
||||
y.shape = (number_of_y_vectors, vector_dim)
|
||||
|
||||
output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
|
||||
"""
|
||||
for tensor in [x, y]:
|
||||
if tensor.ndim != 2:
|
||||
raise ValueError(
|
||||
'The tensor dimension must be two. You provide: tensor.ndim=' +
|
||||
str(tensor.ndim) + '.')
|
||||
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
|
||||
raise ValueError(
|
||||
'The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]='
|
||||
+ str(tuple(x.shape)[1]) + ' and tuple(y.shape)(y)[1]=' +
|
||||
str(tuple(y.shape)[1]) + '.')
|
||||
|
||||
y = torch.transpose(y)
|
||||
|
||||
diss = torch.sum(x**2, axis=1,
|
||||
keepdims=True) - 2 * torch.dot(x, y) + torch.sum(
|
||||
y**2, axis=0, keepdims=True)
|
||||
|
||||
if not squared:
|
||||
if epsilon == 0:
|
||||
diss = torch.sqrt(diss)
|
||||
else:
|
||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||
|
||||
return diss
|
||||
|
||||
|
||||
def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
||||
r""" Tangent distances based on the tensorflow implementation of Sascha Saralajews
|
||||
For more info about Tangen distances see DOI:10.1109/IJCNN.2016.7727534.
|
||||
The subspaces is always assumed as transposed and must be orthogonal!
|
||||
For local non sparse signals subspaces must be provided!
|
||||
shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
||||
shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
||||
shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
||||
subspace should be orthogonalized
|
||||
Pytorch implementation of Sascha Saralajew's tensorflow code.
|
||||
Translation by Christoph Raab
|
||||
"""
|
||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
||||
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
|
||||
subspace_int_shape = tuple(subspaces.shape)
|
||||
|
||||
# check if the shapes are correct
|
||||
_check_shapes(signal_int_shape, proto_int_shape)
|
||||
|
||||
atom_axes = list(range(3, len(signal_int_shape)))
|
||||
# for sparse signals, we use the memory efficient implementation
|
||||
if signal_int_shape[1] == 1:
|
||||
signals = torch.reshape(signals, [-1, np.prod(signal_shape[3:])])
|
||||
|
||||
if len(atom_axes) > 1:
|
||||
protos = torch.reshape(protos, [proto_shape[0], -1])
|
||||
|
||||
if subspaces.ndim == 2:
|
||||
# clean solution without map if the matrix_scope is global
|
||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
|
||||
subspaces, torch.transpose(subspaces))
|
||||
|
||||
projected_signals = torch.dot(signals, projectors)
|
||||
projected_protos = torch.dot(protos, projectors)
|
||||
|
||||
diss = euclidean_distance_matrix(projected_signals,
|
||||
projected_protos,
|
||||
squared=squared,
|
||||
epsilon=epsilon)
|
||||
|
||||
diss = torch.reshape(
|
||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||
|
||||
return torch.permute(diss, [0, 2, 1])
|
||||
|
||||
else:
|
||||
|
||||
# no solution without map possible --> memory efficient but slow!
|
||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
|
||||
subspaces,
|
||||
subspaces) #K.batch_dot(subspaces, subspaces, [2, 2])
|
||||
|
||||
projected_protos = (protos @ subspaces
|
||||
).T #K.batch_dot(projectors, protos, [1, 1]))
|
||||
|
||||
def projected_norm(projector):
|
||||
return torch.sum(torch.dot(signals, projector)**2, axis=1)
|
||||
|
||||
diss = torch.transpose(map(projected_norm, projectors)) \
|
||||
- 2 * torch.dot(signals, projected_protos) \
|
||||
+ torch.sum(projected_protos**2, axis=0, keepdims=True)
|
||||
|
||||
if not squared:
|
||||
if epsilon == 0:
|
||||
diss = torch.sqrt(diss)
|
||||
else:
|
||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||
|
||||
diss = torch.reshape(
|
||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||
|
||||
return torch.permute(diss, [0, 2, 1])
|
||||
|
||||
else:
|
||||
signals = signals.permute([0, 2, 1] + atom_axes)
|
||||
|
||||
diff = signals - protos
|
||||
|
||||
# global tangent space
|
||||
if subspaces.ndim == 2:
|
||||
#Scope Projectors
|
||||
projectors = subspaces #
|
||||
|
||||
#Scope: Tangentspace Projections
|
||||
diff = torch.reshape(
|
||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||
projected_diff = diff @ projectors
|
||||
projected_diff = torch.reshape(
|
||||
projected_diff,
|
||||
(signal_shape[0], signal_shape[2], signal_shape[1]) +
|
||||
signal_shape[3:])
|
||||
|
||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||
return diss.permute([0, 2, 1])
|
||||
|
||||
# local tangent spaces
|
||||
else:
|
||||
# Scope: Calculate Projectors
|
||||
projectors = subspaces
|
||||
|
||||
# Scope: Tangentspace Projections
|
||||
diff = torch.reshape(
|
||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||
diff = diff.permute([1, 0, 2])
|
||||
projected_diff = torch.bmm(diff, projectors)
|
||||
projected_diff = torch.reshape(
|
||||
projected_diff,
|
||||
(signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||
signal_shape[3:])
|
||||
|
||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||
return diss.permute([1, 0, 2]).squeeze(-1)
|
||||
|
||||
|
||||
# Aliases
|
||||
sed = squared_euclidean_distance
|
||||
|
89
prototorch/functions/helper.py
Normal file
89
prototorch/functions/helper.py
Normal file
@ -0,0 +1,89 @@
|
||||
import torch
|
||||
|
||||
|
||||
def calculate_prototype_accuracy(y_pred, y_true, plabels):
|
||||
"""Computes the accuracy of a prototype based model.
|
||||
via Winner-Takes-All rule.
|
||||
Requirement:
|
||||
y_pred.shape == y_true.shape
|
||||
unique(y_pred) in plabels
|
||||
"""
|
||||
with torch.no_grad():
|
||||
idx = torch.argmin(y_pred, axis=1)
|
||||
return torch.true_divide(torch.sum(y_true == plabels[idx]),
|
||||
len(y_pred)) * 100
|
||||
|
||||
|
||||
def predict_label(y_pred, plabels):
|
||||
r""" Predicts labels given a prediction of a prototype based model.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
return plabels[torch.argmin(y_pred, 1)]
|
||||
|
||||
|
||||
def mixed_shape(inputs):
|
||||
if not torch.is_tensor(inputs):
|
||||
raise ValueError('Input must be a tensor.')
|
||||
else:
|
||||
int_shape = list(inputs.shape)
|
||||
# sometimes int_shape returns mixed integer types
|
||||
int_shape = [int(i) if i is not None else i for i in int_shape]
|
||||
tensor_shape = inputs.shape
|
||||
|
||||
for i, s in enumerate(int_shape):
|
||||
if s is None:
|
||||
int_shape[i] = tensor_shape[i]
|
||||
return tuple(int_shape)
|
||||
|
||||
|
||||
def equal_int_shape(shape_1, shape_2):
|
||||
if not isinstance(shape_1,
|
||||
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
|
||||
raise ValueError('Input shapes must list or tuple.')
|
||||
for shape in [shape_1, shape_2]:
|
||||
if not all([isinstance(x, int) or x is None for x in shape]):
|
||||
raise ValueError(
|
||||
'Input shapes must be list or tuple of int and None values.')
|
||||
|
||||
if len(shape_1) != len(shape_2):
|
||||
return False
|
||||
else:
|
||||
for axis, value in enumerate(shape_1):
|
||||
if value is not None and shape_2[axis] not in {value, None}:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_shapes(signal_int_shape, proto_int_shape):
|
||||
if len(signal_int_shape) < 4:
|
||||
raise ValueError(
|
||||
"The number of signal dimensions must be >=4. You provide: " +
|
||||
str(len(signal_int_shape)))
|
||||
|
||||
if len(proto_int_shape) < 2:
|
||||
raise ValueError(
|
||||
"The number of proto dimensions must be >=2. You provide: " +
|
||||
str(len(proto_int_shape)))
|
||||
|
||||
if not equal_int_shape(signal_int_shape[3:], proto_int_shape[1:]):
|
||||
raise ValueError(
|
||||
"The atom shape of signals must be equal protos. You provide: signals.shape[3:]="
|
||||
+ str(signal_int_shape[3:]) + " != protos.shape[1:]=" +
|
||||
str(proto_int_shape[1:]))
|
||||
|
||||
# not a sparse signal
|
||||
if signal_int_shape[1] != 1:
|
||||
if not equal_int_shape(signal_int_shape[1:2], proto_int_shape[0:1]):
|
||||
raise ValueError(
|
||||
"If the signal is not sparse, the number of prototypes must be equal in signals and "
|
||||
"protos. You provide: " + str(signal_int_shape[1]) + " != " +
|
||||
str(proto_int_shape[0]))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _int_and_mixed_shape(tensor):
|
||||
shape = mixed_shape(tensor)
|
||||
int_shape = tuple([i if isinstance(i, int) else None for i in shape])
|
||||
|
||||
return shape, int_shape
|
37
prototorch/functions/normalization.py
Normal file
37
prototorch/functions/normalization.py
Normal file
@ -0,0 +1,37 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def orthogonalization(tensors):
|
||||
r""" Orthogonalization of a given tensor via polar decomposition.
|
||||
"""
|
||||
u, _, v = torch.svd(tensors, compute_uv=True)
|
||||
u_shape = tuple(list(u.shape))
|
||||
v_shape = tuple(list(v.shape))
|
||||
|
||||
# reshape to (num x N x M)
|
||||
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
|
||||
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
|
||||
|
||||
out = u @ v.permute([0, 2, 1])
|
||||
|
||||
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def trace_normalization(tensors, epsilon=[1e-10]):
|
||||
r""" Trace normalization
|
||||
"""
|
||||
epsilon = torch.tensor([1e-10], dtype=torch.float64)
|
||||
# Scope trace_normalization
|
||||
constant = torch.trace(tensors)
|
||||
|
||||
if epsilon != 0:
|
||||
constant = torch.max(constant, epsilon)
|
||||
|
||||
return tensors / constant
|
190
prototorch/modules/models.py
Normal file
190
prototorch/modules/models.py
Normal file
@ -0,0 +1,190 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
from prototorch.functions.helper import _check_shapes,_int_and_mixed_shape
|
||||
|
||||
class GTLVQ(nn.Module):
|
||||
r""" Generalized Tangent Learning Vector Quantization
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes: int
|
||||
Number of classes of the given classification problem.
|
||||
|
||||
subspace_data: torch.tensor of shape (n_batch,feature_dim,feature_dim)
|
||||
Subspace data for the point approximation, required
|
||||
|
||||
prototype_data: torch.tensor of shape (n_init_data,feature_dim) (optional)
|
||||
prototype data for initalization of the prototypes used in GTLVQ.
|
||||
|
||||
tangent_projection_type: string
|
||||
Specifies the tangent projection type
|
||||
options: local
|
||||
local_proj
|
||||
global
|
||||
local: computes the tangent distances without emphasizing projected
|
||||
data. Only distances are available
|
||||
local_proj: computs tangent distances and returns the projected data
|
||||
for further use. Be careful: data is repeated by number of prototypes
|
||||
global: Number of subspaces is set to one and every prototypes
|
||||
uses the same.
|
||||
|
||||
prototypes_per_class: int (default=2,optional)
|
||||
Number of prototypes per class
|
||||
|
||||
feature_dim: int (default=256)
|
||||
Dimensionality of the feature space specified as integer.
|
||||
Prototype dimension.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The GTLVQ [1] is a prototype-based classification learning model. The
|
||||
GTLVQ uses the Tangent-Distances for a local point approximation
|
||||
of an assumed data manifold via prototypial representations.
|
||||
|
||||
The GTLVQ requires subspace projectors for transforming the data
|
||||
and prototypes into the affine subspace. Every prototype is
|
||||
equipped with a specific subpspace and represents a point
|
||||
approximation of the assumed manifold.
|
||||
|
||||
In practice prototypes and data are projected on this manifold
|
||||
and pairwise euclidean distance computes.
|
||||
|
||||
References
|
||||
----------
|
||||
.. [1] Saralajew, Sascha; Villmann, Thomas: Transfer learning
|
||||
in classification based on manifolc. models and its relation
|
||||
to tangent metric learning. In: 2017 International Joint
|
||||
Conference on Neural Networks (IJCNN).
|
||||
Bd. 2017-May : IEEE, 2017, S. 1756–1765
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
num_classes,
|
||||
subspace_data=None,
|
||||
prototype_data=None,
|
||||
subspace_size=256,
|
||||
tangent_projection_type='local',
|
||||
prototypes_per_class=2,
|
||||
feature_dim=256,
|
||||
):
|
||||
super(GTLVQ, self).__init__()
|
||||
|
||||
self.num_protos = num_classes * prototypes_per_class
|
||||
self.subspace_size = feature_dim if subspace_size is None else subspace_size
|
||||
self.feature_dim = feature_dim
|
||||
|
||||
if subspace_data is None:
|
||||
raise ValueError('Init Data must be specified!')
|
||||
|
||||
self.tpt = tangent_projection_type
|
||||
with torch.no_grad():
|
||||
if self.tpt == 'local' or self.tpt == 'local_proj':
|
||||
self.subspaces = torch.nn.Parameter(
|
||||
self.init_local_subspace(
|
||||
subspace_data).clone().detach().requires_grad_(True))
|
||||
elif self.tpt == 'global':
|
||||
self.subspaces = torch.nn.Parameter(
|
||||
self.init_gobal_subspace(
|
||||
subspace_data).clone().detach().requires_grad_(True))
|
||||
else:
|
||||
self.subspaces = None
|
||||
|
||||
# Hypothesis-Margin-Classifier
|
||||
self.cls = Prototypes1D(input_dim=feature_dim,
|
||||
prototypes_per_class=prototypes_per_class,
|
||||
nclasses=num_classes,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=prototype_data)
|
||||
|
||||
def forward(self, x):
|
||||
# Tangent Projection
|
||||
if self.tpt == 'local_proj':
|
||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||
1).unsqueeze(2)
|
||||
dis, proj_x = self.local_tangent_projection(
|
||||
x_conform, self.cls.prototypes, self.subspaces)
|
||||
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
|
||||
self.feature_dim)
|
||||
return proj_x, dis
|
||||
elif self.tpt == "local":
|
||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||
1).unsqueeze(2)
|
||||
dis = tangent_distance(x_conform, self.cls.prototypes,
|
||||
self.subspaces)
|
||||
elif self.tpt == "gloabl":
|
||||
dis = self.global_tangent_distances(x)
|
||||
else:
|
||||
dis = (x @ self.cls.prototypes.T) / (
|
||||
torch.norm(x, dim=1, keepdim=True) @ torch.norm(
|
||||
self.cls.prototypes, dim=1, keepdim=True).T)
|
||||
return dis
|
||||
|
||||
def init_gobal_subspace(self, data, num_subspaces):
|
||||
_, _, v = torch.svd(data)
|
||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||
return subspace[:, :num_subspaces]
|
||||
|
||||
def init_local_subspace(self, data):
|
||||
_, _, v = torch.svd(data)
|
||||
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||
return inital_projector.unsqueeze(0).repeat_interleave(
|
||||
self.num_protos, 0)
|
||||
|
||||
def global_tangent_distances(self, x):
|
||||
# Tangent Projection
|
||||
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
|
||||
# Euclidean Distance
|
||||
return euclidean_distance_matrix(x, projected_prototypes)
|
||||
|
||||
def local_tangent_projection(self,
|
||||
signals,
|
||||
protos,
|
||||
subspaces,
|
||||
squared=False,
|
||||
epsilon=1e-10):
|
||||
# Note: subspaces is always assumed as transposed and must be orthogonal!
|
||||
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
||||
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
||||
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
|
||||
# subspace should be orthogonalized
|
||||
# Origin Source Code
|
||||
# Origin Author:
|
||||
|
||||
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
|
||||
proto_shape, proto_int_shape = _int_and_mixed_shape(protos)
|
||||
|
||||
# check if the shapes are correct
|
||||
_check_shapes(signal_int_shape, proto_int_shape)
|
||||
|
||||
atom_axes = list(range(3, len(signal_int_shape)))
|
||||
|
||||
# Tangent Data Projections
|
||||
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
|
||||
data = signals.squeeze(2).permute([1, 0, 2])
|
||||
projected_data = torch.bmm(data, subspaces)
|
||||
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
|
||||
diff = projected_data - projected_protos
|
||||
projected_diff = torch.reshape(
|
||||
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||
signal_shape[3:])
|
||||
diss = torch.norm(projected_diff,2,dim=-1)
|
||||
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
|
||||
|
||||
def get_parameters(self):
|
||||
return {
|
||||
"params": self.cls.prototypes,
|
||||
}, {
|
||||
"params": self.subspaces
|
||||
}
|
||||
|
||||
def orthogonalize_subspace(self):
|
||||
if self.subspaces is not None:
|
||||
with torch.no_grad():
|
||||
ortho_subpsaces = orthogonalization(
|
||||
self.subspaces
|
||||
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
|
||||
self.subspaces)
|
||||
self.subspaces.copy_(ortho_subpsaces)
|
Loading…
Reference in New Issue
Block a user