Compare commits: v0.2.0...v0.3.0-dev

7 commits:

- ae75b9ebf7
- 34973808b8
- c42df6e203
- 101b50f4e6
- cd9303267b
- 599dfc3fda
- 5b2ab34232
.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.2.0
+current_version = 0.3.0-dev0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
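As a quick sanity check of the bump (a standalone snippet, not part of the repository), the `parse` pattern above decomposes the new pre-release version like this:

    import re

    # The `parse` pattern from .bumpversion.cfg, copied verbatim:
    PARSE = (r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"
             r"(\-(?P<release>[a-z]+)(?P<build>\d+))?")

    match = re.fullmatch(PARSE, "0.3.0-dev0")
    print(match.groupdict())
    # {'major': '0', 'minor': '3', 'patch': '0', 'release': 'dev', 'build': '0'}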
docs (Sphinx reStructuredText API page)
@@ -11,8 +11,26 @@ Datasets
 
 Functions
 --------------------------------------
-.. automodule:: prototorch.functions
+
+**Dimensions:**
+
+- :math:`B` ... Batch size
+- :math:`P` ... Number of prototypes
+- :math:`n_x` ... Data dimension for vectorial data
+- :math:`n_w` ... Data dimension for vectorial prototypes
+
+Activations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: prototorch.functions.activations
    :members:
+   :exclude-members: register_activation, get_activation
+   :undoc-members:
+
+Distances
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: prototorch.functions.distances
+   :members:
+   :exclude-members: sed
    :undoc-members:
 
 Modules
Sphinx conf.py
@@ -24,7 +24,7 @@ author = "Jensun Ravichandran"
 
 # The full version, including alpha/beta/rc tags
 #
-release = "0.2.0"
+release = "0.3.0-dev0"
 
 # -- General configuration ---------------------------------------------------
 
prototorch/__init__.py
@@ -1,11 +1,46 @@
 """ProtoTorch package."""
 
-__version__ = '0.2.0'
+# #############################################
+# Core Setup
+# #############################################
+__version__ = "0.3.0-dev0"
 
 from prototorch import datasets, functions, modules
 
-__all__ = [
-    'datasets',
-    'functions',
-    'modules',
+__all_core__ = [
+    "datasets",
+    "functions",
+    "modules",
 ]
+
+# #############################################
+# Plugin Loader
+# #############################################
+import pkgutil
+
+import pkg_resources
+
+__path__ = pkgutil.extend_path(__path__, __name__)
+
+
+def discover_plugins():
+    return {
+        entry_point.name: entry_point.load()
+        for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
+    }
+
+
+discovered_plugins = discover_plugins()
+locals().update(discovered_plugins)
+
+# Generate combined __version__ and __all__
+version_plugins = "\n".join(
+    [
+        "- " + name + ": v" + plugin.__version__
+        for name, plugin in discovered_plugins.items()
+    ]
+)
+if version_plugins != "":
+    version_plugins = "\nPlugins: \n" + version_plugins
+
+version = "core: v" + __version__ + version_plugins
+__all__ = __all_core__ + list(discovered_plugins.keys())
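The new loader exposes anything registered under the "prototorch.plugins" entry-point group as an attribute of the prototorch namespace. A sketch of what a separate plugin distribution could declare to be picked up; the package and entry-point names here ("prototorch-hello", "hello") are hypothetical, only the group name comes from the loader above:

    # Hypothetical plugin setup.py
    from setuptools import setup

    setup(
        name="prototorch-hello",
        version="0.1.0",
        packages=["prototorch_hello"],
        install_requires=["prototorch"],
        entry_points={
            "prototorch.plugins": [
                # attribute name under `prototorch` -> module to load
                "hello = prototorch_hello",
            ],
        },
    )

Once such a plugin is installed, `pkg_resources.iter_entry_points("prototorch.plugins")` yields it, `prototorch.hello` resolves to the plugin module, and `prototorch.version` lists it under "Plugins:".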
prototorch/datasets/abstract.py
@@ -14,6 +14,7 @@ import torch
 
 class Dataset(torch.utils.data.Dataset):
     """Abstract dataset class to be inherited."""
+
     _repr_indent = 2
 
     def __init__(self, root):
@@ -30,8 +31,9 @@ class Dataset(torch.utils.data.Dataset):
 
 class ProtoDataset(Dataset):
     """Abstract dataset class to be inherited."""
-    training_file = 'training.pt'
-    test_file = 'test.pt'
+
+    training_file = "training.pt"
+    test_file = "test.pt"
 
     def __init__(self, root, train=True, download=True, verbose=True):
         super().__init__(root)
@@ -39,43 +41,44 @@ class ProtoDataset(Dataset):
         self.verbose = verbose
 
         if download:
-            self.download()
+            self._download()
 
         if not self._check_exists():
-            raise RuntimeError('Dataset not found. '
-                               'You can use download=True to download it')
+            raise RuntimeError(
+                "Dataset not found. " "You can use download=True to download it"
+            )
 
         data_file = self.training_file if self.train else self.test_file
         self.data, self.targets = torch.load(
-            os.path.join(self.processed_folder, data_file))
+            os.path.join(self.processed_folder, data_file)
+        )
 
     @property
     def raw_folder(self):
-        return os.path.join(self.root, self.__class__.__name__, 'raw')
+        return os.path.join(self.root, self.__class__.__name__, "raw")
 
     @property
    def processed_folder(self):
-        return os.path.join(self.root, self.__class__.__name__, 'processed')
+        return os.path.join(self.root, self.__class__.__name__, "processed")
 
     @property
     def class_to_idx(self):
         return {_class: i for i, _class in enumerate(self.classes)}
 
     def _check_exists(self):
-        return (os.path.exists(
-            os.path.join(self.processed_folder, self.training_file))
-                and os.path.exists(
-                    os.path.join(self.processed_folder, self.test_file)))
+        return os.path.exists(
+            os.path.join(self.processed_folder, self.training_file)
+        ) and os.path.exists(os.path.join(self.processed_folder, self.test_file))
 
     def __repr__(self):
-        head = 'Dataset ' + self.__class__.__name__
-        body = ['Number of datapoints: {}'.format(self.__len__())]
+        head = "Dataset " + self.__class__.__name__
+        body = ["Number of datapoints: {}".format(self.__len__())]
         if self.root is not None:
-            body.append('Root location: {}'.format(self.root))
+            body.append("Root location: {}".format(self.root))
         body += self.extra_repr().splitlines()
-        lines = [head] + [' ' * self._repr_indent + line for line in body]
-        return '\n'.join(lines)
+        lines = [head] + [" " * self._repr_indent + line for line in body]
+        return "\n".join(lines)
 
     def extra_repr(self):
         return f"Split: {'Train' if self.train is True else 'Test'}"
@@ -83,5 +86,5 @@ class ProtoDataset(Dataset):
     def __len__(self):
         return len(self.data)
 
-    def download(self):
+    def _download(self):
        raise NotImplementedError
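With the hook renamed to the private `_download`, a minimal subclass sketch (not from the repository; the class name and toy data are made up, the contract — save `[data, targets]` with `torch.save` into `processed_folder` under `training_file`/`test_file` — is the one the base class above expects):

    import os

    import torch
    from prototorch.datasets.abstract import ProtoDataset


    class ToyDataset(ProtoDataset):
        """Hypothetical subclass: only `_download` needs implementing now."""

        classes = ["0 - negative", "1 - positive"]

        def _download(self):
            if self._check_exists():
                return
            os.makedirs(self.processed_folder, exist_ok=True)
            data = torch.randn(100, 2)
            targets = (data.sum(dim=1) > 0).long()
            for fname in (self.training_file, self.test_file):
                with open(os.path.join(self.processed_folder, fname), "wb") as f:
                    torch.save([data, targets], f)


    # ds = ToyDataset("./artifacts")  # base __init__ calls _download()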
prototorch/datasets/tecator.py
@@ -46,42 +46,45 @@ from prototorch.datasets.abstract import ProtoDataset
 
 
 class Tecator(ProtoDataset):
-    """Tecator dataset for classification."""
-    resources = [
-        ('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
-         'ba5607c580d0f91bb27dc29d13c2f8df'),
+    """
+    `Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
+    for classification.
+    """
+
+    _resources = [
+        ("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"),
     ]  # (google_storage_id, md5hash)
-    classes = ['0 - low_fat', '1 - high_fat']
+    classes = ["0 - low_fat", "1 - high_fat"]
 
     def __getitem__(self, index):
         img, target = self.data[index], int(self.targets[index])
         return img, target
 
-    def download(self):
+    def _download(self):
         """Download the data if it doesn't already exist."""
         if self._check_exists():
             return
 
         if self.verbose:
-            print('Making directories...')
+            print("Making directories...")
         os.makedirs(self.raw_folder, exist_ok=True)
         os.makedirs(self.processed_folder, exist_ok=True)
 
         if self.verbose:
-            print('Downloading...')
-        for fileid, md5 in self.resources:
-            filename = 'tecator.npz'
-            download_file_from_google_drive(fileid,
-                                            root=self.raw_folder,
-                                            filename=filename,
-                                            md5=md5)
+            print("Downloading...")
+        for fileid, md5 in self._resources:
+            filename = "tecator.npz"
+            download_file_from_google_drive(
+                fileid, root=self.raw_folder, filename=filename, md5=md5
+            )
 
         if self.verbose:
-            print('Processing...')
-        with np.load(os.path.join(self.raw_folder, 'tecator.npz'),
-                     allow_pickle=False) as f:
-            x_train, y_train = f['x_train'], f['y_train']
-            x_test, y_test = f['x_test'], f['y_test']
+            print("Processing...")
+        with np.load(
+            os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False
+        ) as f:
+            x_train, y_train = f["x_train"], f["y_train"]
+            x_test, y_test = f["x_test"], f["y_test"]
             training_set = [
                 torch.tensor(x_train, dtype=torch.float32),
                 torch.tensor(y_train),
@@ -91,12 +94,10 @@ class Tecator(ProtoDataset):
             torch.tensor(y_test),
         ]
 
-        with open(os.path.join(self.processed_folder, self.training_file),
-                  'wb') as f:
+        with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
             torch.save(training_set, f)
-        with open(os.path.join(self.processed_folder, self.test_file),
-                  'wb') as f:
+        with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
             torch.save(test_set, f)
 
         if self.verbose:
-            print('Done!')
+            print("Done!")
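How the renamed hook plays out in practice; a usage sketch (the import path is inferred from the class and package layout shown above, everything else follows the ProtoDataset constructor):

    from prototorch.datasets.tecator import Tecator

    # root/train/download/verbose come from ProtoDataset.__init__;
    # __getitem__ returns a (sample, target) pair.
    train_data = Tecator(root="./artifacts", train=True)
    x, y = train_data[0]
    print(train_data)      # uses the Dataset __repr__ shown above
    print(x.shape, y)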
prototorch/functions/distances.py
@@ -1,15 +1,24 @@
 """ProtoTorch distance functions."""
 
 import torch
-from prototorch.functions.helper import equal_int_shape, _int_and_mixed_shape, _check_shapes
+from prototorch.functions.helper import (
+    equal_int_shape,
+    _int_and_mixed_shape,
+    _check_shapes,
+)
 import numpy as np
 
 
 def squared_euclidean_distance(x, y):
-    """Compute the squared Euclidean distance between :math:`x` and :math:`y`.
+    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.
 
-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
+    Compute :math:`{\| \bm x - \bm y \|}_2^2`
+
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+
+    **Alias:**
+    ``prototorch.functions.distances.sed``
     """
     expanded_x = x.unsqueeze(dim=1)
     batchwise_difference = y - expanded_x
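A quick check of the documented shape convention (B x n_x inputs against P x n_w prototypes), assuming the truncated tail of the function sums the squared differences over the last axis as the first two lines suggest:

    import torch
    from prototorch.functions.distances import squared_euclidean_distance

    x = torch.randn(32, 3)   # (B, n_x)
    y = torch.randn(10, 3)   # (P, n_w), with n_x == n_w
    d = squared_euclidean_distance(x, y)
    assert d.shape == (32, 10)
    # agrees with the squared pairwise distances from torch.cdist
    assert torch.allclose(d, torch.cdist(x, y) ** 2, atol=1e-5)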
@@ -19,10 +28,15 @@ def squared_euclidean_distance(x, y):
 
 
 def euclidean_distance(x, y):
-    """Compute the Euclidean distance between :math:`x` and :math:`y`.
+    r"""Compute the Euclidean distance between :math:`x` and :math:`y`.
 
-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
+    Compute :math:`\sqrt{{\| \bm x - \bm y \|}_2^2}`
+
+    :param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
+    :param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
+
+    :returns: Distance Tensor of shape :math:`X \times Y`
+    :rtype: `torch.tensor`
     """
     distances_raised = squared_euclidean_distance(x, y)
     distances = torch.sqrt(distances_raised)
@@ -30,10 +44,17 @@ def euclidean_distance(x, y):
 
 
 def lpnorm_distance(x, y, p):
-    r"""Compute :math:`{\langle x, y \rangle}_p`.
+    r"""
+    Calculates the lp-norm between :math:`\bm x` and :math:`\bm y`.
+    Also known as Minkowski distance.
 
-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
+    Compute :math:`{\| \bm x - \bm y \|}_p`.
+
+    Calls ``torch.cdist``
+
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+    :param p: p parameter of the lp norm
     """
     distances = torch.cdist(x, y, p=p)
     return distances
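Since the function simply forwards to torch.cdist, a two-line check (outside the repository) confirms the Minkowski behaviour:

    import torch
    from prototorch.functions.distances import lpnorm_distance

    x, y = torch.randn(8, 3), torch.randn(4, 3)
    assert torch.allclose(lpnorm_distance(x, y, p=2), torch.cdist(x, y, p=2))
    # p=1 yields the Manhattan distance
    assert torch.allclose(lpnorm_distance(x, y, p=1),
                          (x.unsqueeze(1) - y).abs().sum(dim=-1))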
@@ -42,11 +63,11 @@ def lpnorm_distance(x, y, p):
 def omega_distance(x, y, omega):
     r"""Omega distance.
 
-    Compute :math:`{\langle \Omega x, \Omega y \rangle}_p`
+    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
 
-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
-    Expected dimension of omega is 2.
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+    :param `torch.tensor` omega: Two dimensional matrix
     """
     projected_x = x @ omega
     projected_y = y @ omega
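The projection step shown above is the GMLVQ-style trick: measure distance in the space Omega maps into. A sketch under the assumption that the truncated remainder of the function applies the squared Euclidean distance to the projections:

    import torch

    x = torch.randn(4, 5)        # data, (B, n_x)
    y = torch.randn(3, 5)        # prototypes, (P, n_w)
    omega = torch.randn(5, 2)    # learned projection, n_x -> 2

    projected_x = x @ omega
    projected_y = y @ omega
    d = torch.cdist(projected_x, projected_y) ** 2   # (4, 3)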
@@ -57,48 +78,55 @@ def omega_distance(x, y, omega):
 def lomega_distance(x, y, omegas):
     r"""Localized Omega distance.
 
-    Compute :math:`{\langle \Omega_k x, \Omega_k y_k \rangle}_p`
+    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
 
-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
-    Expected dimension of omegas is 3.
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+    :param `torch.tensor` omegas: Three dimensional matrix
     """
     projected_x = x @ omegas
     projected_y = torch.diagonal(y @ omegas).T
     expanded_y = torch.unsqueeze(projected_y, dim=1)
     batchwise_difference = expanded_y - projected_x
-    differences_squared = batchwise_difference**2
+    differences_squared = batchwise_difference ** 2
     distances = torch.sum(differences_squared, dim=2)
     distances = distances.permute(1, 0)
     return distances
 
 
 def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
-    r""" Computes an euclidean distanes matrix given two distinct vectors.
+    r"""Computes a Euclidean distance matrix given two distinct sets of vectors.
     last dimension must be the vector dimension!
-    compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
+    computes the distance via the dot-product identity, which avoids the memory overhead of the pairwise subtraction!
 
-    x.shape = (number_of_x_vectors, vector_dim)
-    y.shape = (number_of_y_vectors, vector_dim)
+    - ``x.shape = (number_of_x_vectors, vector_dim)``
+    - ``y.shape = (number_of_y_vectors, vector_dim)``
 
     output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
     """
     for tensor in [x, y]:
         if tensor.ndim != 2:
-            raise ValueError(
-                'The tensor dimension must be two. You provide: tensor.ndim=' +
-                str(tensor.ndim) + '.')
+            raise ValueError(
+                "The tensor dimension must be two. You provided: tensor.ndim="
+                + str(tensor.ndim)
+                + "."
+            )
     if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
         raise ValueError(
-            'The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]='
-            + str(tuple(x.shape)[1]) + ' and tuple(y.shape)(y)[1]=' +
-            str(tuple(y.shape)[1]) + '.')
+            "The vector shape must be equivalent in both tensors. You provided: tuple(x.shape)[1]="
+            + str(tuple(x.shape)[1])
+            + " and tuple(y.shape)[1]="
+            + str(tuple(y.shape)[1])
+            + "."
+        )
 
     y = torch.transpose(y)
 
-    diss = torch.sum(x**2, axis=1,
-                     keepdims=True) - 2 * torch.dot(x, y) + torch.sum(
-                         y**2, axis=0, keepdims=True)
+    diss = (
+        torch.sum(x ** 2, axis=1, keepdims=True)
+        - 2 * torch.dot(x, y)
+        + torch.sum(y ** 2, axis=0, keepdims=True)
+    )
 
     if not squared:
         if epsilon == 0:
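The identity behind this function is ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2, which never materialises the (X, Y, N) difference tensor. A standalone sketch with working 2-D ops (note that torch.dot and one-argument torch.transpose, as retained in the hunk, only accept 1-D tensors or need explicit dims, so @ and .T are used here instead):

    import torch

    x = torch.randn(6, 4)
    y = torch.randn(5, 4)

    sq = (
        torch.sum(x ** 2, dim=1, keepdim=True)      # (6, 1)
        - 2 * (x @ y.T)                             # (6, 5)
        + torch.sum(y ** 2, dim=1).unsqueeze(0)     # (1, 5)
    )
    dist = torch.sqrt(torch.clamp(sq, min=1e-10))   # epsilon guards sqrt near 0
    assert torch.allclose(dist, torch.cdist(x, y), atol=1e-4)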
@@ -110,13 +138,19 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
 
 
 def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
-    r""" Tangent distances based on the tensorflow implementation of Sascha Saralajews
-    For more info about Tangen distances see DOI:10.1109/IJCNN.2016.7727534.
+    r"""Tangent distances based on the tensorflow implementation of Sascha Saralajew
+
+    For more info about tangent distances see
+
+    DOI:10.1109/IJCNN.2016.7727534.
+
     The subspaces are always assumed to be transposed and must be orthogonal!
     For local non-sparse signals, subspaces must be provided!
-    shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
-    shape(protos): proto_number x dim1 x dim2 x ... x dimN
-    shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
+
+    - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
+    - shape(protos): proto_number x dim1 x dim2 x ... x dimN
+    - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
+
     subspace should be orthogonalized
     Pytorch implementation of Sascha Saralajew's tensorflow code.
     Translation by Christoph Raab
@@ -139,18 +173,19 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
     if subspaces.ndim == 2:
         # clean solution without map if the matrix_scope is global
         projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
-            subspaces, torch.transpose(subspaces))
+            subspaces, torch.transpose(subspaces)
+        )
 
         projected_signals = torch.dot(signals, projectors)
         projected_protos = torch.dot(protos, projectors)
 
-        diss = euclidean_distance_matrix(projected_signals,
-                                         projected_protos,
-                                         squared=squared,
-                                         epsilon=epsilon)
+        diss = euclidean_distance_matrix(
+            projected_signals, projected_protos, squared=squared, epsilon=epsilon
+        )
 
         diss = torch.reshape(
-            diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
+            diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
+        )
 
         return torch.permute(diss, [0, 2, 1])
 
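The complement projector built here is P = I - U U^T for an orthonormal subspace basis U; distances are then measured on the projected signals and prototypes. A standalone sketch of that construction, using 2-D-capable ops in place of torch.dot:

    import torch

    dim, k = 5, 2
    u, _ = torch.linalg.qr(torch.randn(dim, k))   # orthonormal basis, (dim, k)
    projector = torch.eye(dim) - u @ u.T          # P = I - U U^T

    # P is idempotent and annihilates the subspace spanned by U:
    assert torch.allclose(projector @ projector, projector, atol=1e-6)
    assert torch.allclose(projector @ u, torch.zeros(dim, k), atol=1e-6)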
@@ -158,18 +193,21 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 
         # no solution without map possible --> memory efficient but slow!
         projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
-            subspaces,
-            subspaces)  #K.batch_dot(subspaces, subspaces, [2, 2])
+            subspaces, subspaces
+        )  # K.batch_dot(subspaces, subspaces, [2, 2])
 
-        projected_protos = (protos @ subspaces
-                            ).T  #K.batch_dot(projectors, protos, [1, 1]))
+        projected_protos = (
+            protos @ subspaces
+        ).T  # K.batch_dot(projectors, protos, [1, 1]))
 
         def projected_norm(projector):
-            return torch.sum(torch.dot(signals, projector)**2, axis=1)
+            return torch.sum(torch.dot(signals, projector) ** 2, axis=1)
 
-        diss = torch.transpose(map(projected_norm, projectors)) \
-            - 2 * torch.dot(signals, projected_protos) \
-            + torch.sum(projected_protos**2, axis=0, keepdims=True)
+        diss = (
+            torch.transpose(map(projected_norm, projectors))
+            - 2 * torch.dot(signals, projected_protos)
+            + torch.sum(projected_protos ** 2, axis=0, keepdims=True)
+        )
 
         if not squared:
             if epsilon == 0:
@@ -178,7 +216,8 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
                 diss = torch.sqrt(torch.max(diss, epsilon))
 
         diss = torch.reshape(
-            diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
+            diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
+        )
 
         return torch.permute(diss, [0, 2, 1])
 
@@ -189,17 +228,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 
         # global tangent space
         if subspaces.ndim == 2:
-            #Scope Projectors
+            # Scope Projectors
             projectors = subspaces  #
 
-            #Scope: Tangentspace Projections
+            # Scope: Tangentspace Projections
             diff = torch.reshape(
-                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
+            )
             projected_diff = diff @ projectors
             projected_diff = torch.reshape(
                 projected_diff,
-                (signal_shape[0], signal_shape[2], signal_shape[1]) +
-                signal_shape[3:])
+                (signal_shape[0], signal_shape[2], signal_shape[1]) + signal_shape[3:],
+            )
 
             diss = torch.norm(projected_diff, 2, dim=-1)
             return diss.permute([0, 2, 1])
@@ -211,13 +251,14 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
 
             # Scope: Tangentspace Projections
             diff = torch.reshape(
-                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
+            )
             diff = diff.permute([1, 0, 2])
             projected_diff = torch.bmm(diff, projectors)
             projected_diff = torch.reshape(
                 projected_diff,
-                (signal_shape[1], signal_shape[0], signal_shape[2]) +
-                signal_shape[3:])
+                (signal_shape[1], signal_shape[0], signal_shape[2]) + signal_shape[3:],
+            )
 
             diss = torch.norm(projected_diff, 2, dim=-1)
             return diss.permute([1, 0, 2]).squeeze(-1)
prototorch/modules/prototypes.py
@@ -14,11 +14,11 @@ class _Prototypes(torch.nn.Module):
 
     def _validate_prototype_distribution(self):
         if 0 in self.prototype_distribution:
-            warnings.warn('Are you sure about the `0` in '
-                          '`prototype_distribution`?')
+            warnings.warn("Are you sure about the `0` in "
+                          "`prototype_distribution`?")
 
     def extra_repr(self):
-        return f'prototypes.shape: {tuple(self.prototypes.shape)}'
+        return f"prototypes.shape: {tuple(self.prototypes.shape)}"
 
     def forward(self):
         return self.prototypes, self.prototype_labels
@@ -31,7 +31,7 @@ class Prototypes1D(_Prototypes):
     """
     def __init__(self,
                  prototypes_per_class=1,
-                 prototype_initializer='ones',
+                 prototype_initializer="ones",
                  prototype_distribution=None,
                  data=None,
                  dtype=torch.float32,
@@ -44,25 +44,25 @@ class Prototypes1D(_Prototypes):
             prototype_distribution = prototype_distribution.tolist()
 
         if data is None:
-            if 'input_dim' not in kwargs:
-                raise NameError('`input_dim` required if '
-                                'no `data` is provided.')
+            if "input_dim" not in kwargs:
+                raise NameError("`input_dim` required if "
+                                "no `data` is provided.")
             if prototype_distribution:
                 kwargs_nclasses = sum(prototype_distribution)
             else:
-                if 'nclasses' not in kwargs:
-                    raise NameError('`prototype_distribution` required if '
-                                    'both `data` and `nclasses` are not '
-                                    'provided.')
-                kwargs_nclasses = kwargs.pop('nclasses')
+                if "nclasses" not in kwargs:
+                    raise NameError("`prototype_distribution` required if "
+                                    "both `data` and `nclasses` are not "
+                                    "provided.")
+                kwargs_nclasses = kwargs.pop("nclasses")
-            input_dim = kwargs.pop('input_dim')
+            input_dim = kwargs.pop("input_dim")
             if prototype_initializer in [
-                    'stratified_mean', 'stratified_random'
+                    "stratified_mean", "stratified_random"
             ]:
                 warnings.warn(
-                    f'`prototype_initializer`: `{prototype_initializer}` '
-                    'requires `data`, but `data` is not provided. '
-                    'Using randomly generated data instead.')
+                    f"`prototype_initializer`: `{prototype_initializer}` "
+                    "requires `data`, but `data` is not provided. "
+                    "Using randomly generated data instead.")
             x_train = torch.rand(kwargs_nclasses, input_dim)
             y_train = torch.arange(kwargs_nclasses)
             if one_hot_labels:
@@ -75,39 +75,39 @@ class Prototypes1D(_Prototypes):
         nclasses = torch.unique(y_train, dim=-1).shape[-1]
 
         if nclasses == 1:
-            warnings.warn('Are you sure about having one class only?')
+            warnings.warn("Are you sure about having one class only?")
 
         if x_train.ndim != 2:
-            raise ValueError('`data[0].ndim != 2`.')
+            raise ValueError("`data[0].ndim != 2`.")
 
         if y_train.ndim == 2:
             if y_train.shape[1] == 1 and one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `True` '
-                                 'but target labels are not one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `True` "
+                                 "but target labels are not one-hot-encoded.")
             if y_train.shape[1] != 1 and not one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `False` '
-                                 'but target labels in `data` '
-                                 'are one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `False` "
+                                 "but target labels in `data` "
+                                 "are one-hot-encoded.")
         if y_train.ndim == 1 and one_hot_labels:
-            raise ValueError('`one_hot_labels` is set to `True` '
-                             'but target labels are not one-hot-encoded.')
+            raise ValueError("`one_hot_labels` is set to `True` "
+                             "but target labels are not one-hot-encoded.")
 
         # Verify input dimension if `input_dim` is provided
-        if 'input_dim' in kwargs:
-            input_dim = kwargs.pop('input_dim')
+        if "input_dim" in kwargs:
+            input_dim = kwargs.pop("input_dim")
             if input_dim != x_train.shape[1]:
-                raise ValueError(f'Provided `input_dim`={input_dim} does '
-                                 'not match data dimension '
-                                 f'`data[0].shape[1]`={x_train.shape[1]}')
+                raise ValueError(f"Provided `input_dim`={input_dim} does "
+                                 "not match data dimension "
+                                 f"`data[0].shape[1]`={x_train.shape[1]}")
 
         # Verify the number of classes if `nclasses` is provided
-        if 'nclasses' in kwargs:
-            kwargs_nclasses = kwargs.pop('nclasses')
+        if "nclasses" in kwargs:
+            kwargs_nclasses = kwargs.pop("nclasses")
             if kwargs_nclasses != nclasses:
-                raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
-                                 'not match data labels '
-                                 '`torch.unique(data[1]).shape[0]`'
-                                 f'={nclasses}')
+                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
+                                 "not match data labels "
+                                 "`torch.unique(data[1]).shape[0]`"
+                                 f"={nclasses}")
 
         super().__init__(**kwargs)
 
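A usage sketch pieced together from the checks above; the module path is inferred from the package layout, while everything else (`data` as an `[inputs, targets]` pair, the `stratified_mean` initializer, the `(prototypes, labels)` return of `forward`) appears in this diff:

    import torch
    from prototorch.modules.prototypes import Prototypes1D

    x_train = torch.randn(100, 2)                  # data[0]: 2-D inputs
    y_train = torch.randint(0, 3, (100,))          # data[1]: integer labels
    protos = Prototypes1D(
        prototypes_per_class=2,
        prototype_initializer="stratified_mean",   # requires `data`
        data=[x_train, y_train],
    )
    prototypes, prototype_labels = protos()        # forward() from _Prototypes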
setup.py (87 changed lines)
@@ -1,5 +1,13 @@
-"""Install ProtoTorch."""
+"""
+ _____           _       _______                   _
+|  __ \         | |     |__   __|                 | |
+| |__) | __ ___ | |_ ___   | | ___  _ __ ___| |__
+|  ___/ '__/ _ \| __/ _ \  | |/ _ \| '__/ __| '_ \
+| |   | | | (_) | || (_) | | | (_) | | | (__| | | |
+|_|   |_|  \___/ \__\___/  |_|\___/|_|  \___|_| |_|
+
+ProtoTorch Core Package
+"""
 from setuptools import setup
 from setuptools import find_packages
 
@@ -32,40 +40,43 @@ EXAMPLES = [
 TESTS = ["pytest"]
 ALL = DOCS + DATASETS + EXAMPLES + TESTS
 
-setup(name="prototorch",
-      version="0.2.0",
-      description="Highly extensible, GPU-supported "
-      "Learning Vector Quantization (LVQ) toolbox "
-      "built using PyTorch and its nn API.",
-      long_description=long_description,
-      long_description_content_type="text/markdown",
-      author="Jensun Ravichandran",
-      author_email="jjensun@gmail.com",
-      url=PROJECT_URL,
-      download_url=DOWNLOAD_URL,
-      license="MIT",
-      install_requires=INSTALL_REQUIRES,
-      extras_require={
-          "docs": DOCS,
-          "datasets": DATASETS,
-          "examples": EXAMPLES,
-          "tests": TESTS,
-          "all": ALL,
-      },
-      classifiers=[
-          "Development Status :: 2 - Pre-Alpha",
-          "Environment :: Console",
-          "Intended Audience :: Developers",
-          "Intended Audience :: Education",
-          "Intended Audience :: Science/Research",
-          "License :: OSI Approved :: MIT License",
-          "Natural Language :: English",
-          "Programming Language :: Python :: 3.6",
-          "Programming Language :: Python :: 3.7",
-          "Programming Language :: Python :: 3.8",
-          "Operating System :: OS Independent",
-          "Topic :: Scientific/Engineering :: Artificial Intelligence",
-          "Topic :: Software Development :: Libraries",
-          "Topic :: Software Development :: Libraries :: Python Modules",
-      ],
-      packages=find_packages())
+setup(
+    name="prototorch",
+    version="0.3.0-dev0",
+    description="Highly extensible, GPU-supported "
+    "Learning Vector Quantization (LVQ) toolbox "
+    "built using PyTorch and its nn API.",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    author="Jensun Ravichandran",
+    author_email="jjensun@gmail.com",
+    url=PROJECT_URL,
+    download_url=DOWNLOAD_URL,
+    license="MIT",
+    install_requires=INSTALL_REQUIRES,
+    extras_require={
+        "docs": DOCS,
+        "datasets": DATASETS,
+        "examples": EXAMPLES,
+        "tests": TESTS,
+        "all": ALL,
+    },
+    classifiers=[
+        "Development Status :: 2 - Pre-Alpha",
+        "Environment :: Console",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Education",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: MIT License",
+        "Natural Language :: English",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Operating System :: OS Independent",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+    ],
+    packages=find_packages(),
+    zip_safe=False,
+)
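With the extras table unchanged, a development checkout of this pre-release installs with all optional dependencies via pip install -e ".[all]", or a subset such as [docs], [datasets], [examples], or [tests].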