7 Commits

Author SHA1 Message Date
Alexander Engelsberger
ae75b9ebf7 Bump version: 0.2.0 → 0.3.0-dev0 2021-04-21 14:57:45 +02:00
Alexander Engelsberger
34973808b8 Improve documentation. 2021-04-21 14:55:54 +02:00
Alexander Engelsberger
c42df6e203 Merge version 0.2.0 into feature/plugin-architecture. 2021-04-19 16:44:26 +02:00
Jensun Ravichandran
101b50f4e6 Update prototypes.py (change single-quotes to double-quotes) 2021-04-15 12:35:06 +02:00
Alexander Engelsberger
cd9303267b Use git version. 2021-04-14 13:48:00 +02:00
Alexander Engelsberger
599dfc3fda Fix issue with plugin subpackage import. 2021-04-13 22:55:49 +02:00
Alexander Engelsberger
5b2ab34232 Add plugin loader. 2021-04-13 12:36:22 +02:00
9 changed files with 292 additions and 183 deletions

View File

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.2.0
+current_version = 0.3.0-dev0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
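
The parse pattern above is what lets bumpversion understand pre-release versions such as 0.3.0-dev0. A small sketch (plain Python, not part of the diff) of how that regex decomposes the new version string:

    import re

    # Same pattern as in the bumpversion config above.
    PARSE = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?"

    match = re.fullmatch(PARSE, "0.3.0-dev0")
    print(match.groupdict())
    # {'major': '0', 'minor': '3', 'patch': '0', 'release': 'dev', 'build': '0'}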

View File

@@ -11,8 +11,26 @@ Datasets
 Functions
 --------------------------------------
-.. automodule:: prototorch.functions
+
+**Dimensions:**
+
+- :math:`B` ... Batch size
+- :math:`P` ... Number of prototypes
+- :math:`n_x` ... Data dimension for vectorial data
+- :math:`n_w` ... Data dimension for vectorial prototypes
+
+Activations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: prototorch.functions.activations
    :members:
+   :exclude-members: register_activation, get_activation
+   :undoc-members:
+
+Distances
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: prototorch.functions.distances
+   :members:
+   :exclude-members: sed
    :undoc-members:

 Modules
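
To make the notation concrete: the distance functions documented here take data of shape (B, n_x) and prototypes of shape (P, n_w) and return a (B, P) matrix of distances. A tiny illustrative sketch (sizes invented):

    import torch
    from prototorch.functions.distances import euclidean_distance

    B, P, n_x, n_w = 32, 6, 100, 100   # batch size, prototypes, data dim, prototype dim
    x = torch.rand(B, n_x)             # vectorial data
    w = torch.rand(P, n_w)             # vectorial prototypes
    d = euclidean_distance(x, w)       # expected shape: (B, P)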

View File

@@ -24,7 +24,7 @@ author = "Jensun Ravichandran"

 # The full version, including alpha/beta/rc tags
 #
-release = "0.2.0"
+release = "0.3.0-dev0"

 # -- General configuration ---------------------------------------------------

View File

@@ -1,11 +1,46 @@
 """ProtoTorch package."""

-__version__ = '0.2.0'
+# #############################################
+# Core Setup
+# #############################################
+__version__ = "0.3.0-dev0"

 from prototorch import datasets, functions, modules

-__all__ = [
-    'datasets',
-    'functions',
-    'modules',
+__all_core__ = [
+    "datasets",
+    "functions",
+    "modules",
 ]
+
+# #############################################
+# Plugin Loader
+# #############################################
+import pkgutil
+import pkg_resources
+
+__path__ = pkgutil.extend_path(__path__, __name__)
+
+
+def discover_plugins():
+    return {
+        entry_point.name: entry_point.load()
+        for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
+    }
+
+
+discovered_plugins = discover_plugins()
+locals().update(discovered_plugins)
+
+# Generate combines __version__ and __all__
+version_plugins = "\n".join(
+    [
+        "- " + name + ": v" + plugin.__version__
+        for name, plugin in discovered_plugins.items()
+    ]
+)
+if version_plugins != "":
+    version_plugins = "\nPlugins: \n" + version_plugins
+
+version = "core: v" + __version__ + version_plugins
+__all__ = __all_core__ + list(discovered_plugins.keys())
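
The loader above scans the "prototorch.plugins" entry-point group, exposes each loaded plugin as an attribute of the prototorch package, and appends the plugin versions to the combined version string. As a rough sketch (the plugin name "prototorch_myplugin" is hypothetical, not taken from this diff), a separate package would register itself like this in its own setup.py:

    # setup.py of a hypothetical plugin distribution
    from setuptools import setup, find_packages

    setup(
        name="prototorch_myplugin",
        version="0.1.0",
        packages=find_packages(),
        install_requires=["prototorch"],
        # Entry point in the group scanned by discover_plugins() above;
        # the referenced module should define __version__, which then shows
        # up in prototorch.version next to the core version.
        entry_points={
            "prototorch.plugins": ["myplugin = prototorch_myplugin"]
        },
    )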

View File

@@ -14,6 +14,7 @@ import torch

 class Dataset(torch.utils.data.Dataset):
     """Abstract dataset class to be inherited."""
+
     _repr_indent = 2

     def __init__(self, root):
@@ -30,8 +31,9 @@ class Dataset(torch.utils.data.Dataset):

 class ProtoDataset(Dataset):
     """Abstract dataset class to be inherited."""
-    training_file = 'training.pt'
-    test_file = 'test.pt'
+
+    training_file = "training.pt"
+    test_file = "test.pt"

     def __init__(self, root, train=True, download=True, verbose=True):
         super().__init__(root)
@@ -39,43 +41,44 @@ class ProtoDataset(Dataset):
         self.verbose = verbose

         if download:
-            self.download()
+            self._download()

         if not self._check_exists():
-            raise RuntimeError('Dataset not found. '
-                               'You can use download=True to download it')
+            raise RuntimeError(
+                "Dataset not found. " "You can use download=True to download it"
+            )

         data_file = self.training_file if self.train else self.test_file
         self.data, self.targets = torch.load(
-            os.path.join(self.processed_folder, data_file))
+            os.path.join(self.processed_folder, data_file)
+        )

     @property
     def raw_folder(self):
-        return os.path.join(self.root, self.__class__.__name__, 'raw')
+        return os.path.join(self.root, self.__class__.__name__, "raw")

     @property
     def processed_folder(self):
-        return os.path.join(self.root, self.__class__.__name__, 'processed')
+        return os.path.join(self.root, self.__class__.__name__, "processed")

     @property
     def class_to_idx(self):
         return {_class: i for i, _class in enumerate(self.classes)}

     def _check_exists(self):
-        return (os.path.exists(
-            os.path.join(self.processed_folder, self.training_file))
-                and os.path.exists(
-                    os.path.join(self.processed_folder, self.test_file)))
+        return os.path.exists(
+            os.path.join(self.processed_folder, self.training_file)
+        ) and os.path.exists(os.path.join(self.processed_folder, self.test_file))

     def __repr__(self):
-        head = 'Dataset ' + self.__class__.__name__
-        body = ['Number of datapoints: {}'.format(self.__len__())]
+        head = "Dataset " + self.__class__.__name__
+        body = ["Number of datapoints: {}".format(self.__len__())]
         if self.root is not None:
-            body.append('Root location: {}'.format(self.root))
+            body.append("Root location: {}".format(self.root))
         body += self.extra_repr().splitlines()
-        lines = [head] + [' ' * self._repr_indent + line for line in body]
-        return '\n'.join(lines)
+        lines = [head] + [" " * self._repr_indent + line for line in body]
+        return "\n".join(lines)

     def extra_repr(self):
         return f"Split: {'Train' if self.train is True else 'Test'}"
@@ -83,5 +86,5 @@ class ProtoDataset(Dataset):
     def __len__(self):
         return len(self.data)

-    def download(self):
+    def _download(self):
         raise NotImplementedError
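
With the rename of download() to _download(), subclasses now override the private hook; the constructor still calls it when download=True and then loads training.pt or test.pt from the processed folder. A minimal hypothetical subclass (all names invented for illustration) could look like:

    import os
    import torch
    from prototorch.datasets.abstract import ProtoDataset

    class ToyDataset(ProtoDataset):
        """Hypothetical dataset that materializes random tensors locally."""

        def _download(self):
            if self._check_exists():
                return
            os.makedirs(self.processed_folder, exist_ok=True)
            # ProtoDataset.__init__ expects (data, targets) stored under
            # training.pt / test.pt inside self.processed_folder.
            for fname in (self.training_file, self.test_file):
                data = [torch.rand(100, 10), torch.randint(0, 2, (100,))]
                with open(os.path.join(self.processed_folder, fname), "wb") as f:
                    torch.save(data, f)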

View File

@@ -46,42 +46,45 @@ from prototorch.datasets.abstract import ProtoDataset

 class Tecator(ProtoDataset):
-    """Tecator dataset for classification."""
-    resources = [
-        ('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
-         'ba5607c580d0f91bb27dc29d13c2f8df'),
+    """
+    `Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
+    for classification.
+    """
+
+    _resources = [
+        ("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"),
     ]  # (google_storage_id, md5hash)
-    classes = ['0 - low_fat', '1 - high_fat']
+    classes = ["0 - low_fat", "1 - high_fat"]

     def __getitem__(self, index):
         img, target = self.data[index], int(self.targets[index])
         return img, target

-    def download(self):
+    def _download(self):
         """Download the data if it doesn't exist in already."""
         if self._check_exists():
             return

         if self.verbose:
-            print('Making directories...')
+            print("Making directories...")
         os.makedirs(self.raw_folder, exist_ok=True)
         os.makedirs(self.processed_folder, exist_ok=True)

         if self.verbose:
-            print('Downloading...')
-        for fileid, md5 in self.resources:
-            filename = 'tecator.npz'
-            download_file_from_google_drive(fileid,
-                                            root=self.raw_folder,
-                                            filename=filename,
-                                            md5=md5)
+            print("Downloading...")
+        for fileid, md5 in self._resources:
+            filename = "tecator.npz"
+            download_file_from_google_drive(
+                fileid, root=self.raw_folder, filename=filename, md5=md5
+            )

         if self.verbose:
-            print('Processing...')
-        with np.load(os.path.join(self.raw_folder, 'tecator.npz'),
-                     allow_pickle=False) as f:
-            x_train, y_train = f['x_train'], f['y_train']
-            x_test, y_test = f['x_test'], f['y_test']
+            print("Processing...")
+        with np.load(
+            os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False
+        ) as f:
+            x_train, y_train = f["x_train"], f["y_train"]
+            x_test, y_test = f["x_test"], f["y_test"]
         training_set = [
             torch.tensor(x_train, dtype=torch.float32),
             torch.tensor(y_train),
@@ -91,12 +94,10 @@ class Tecator(ProtoDataset):
             torch.tensor(y_test),
         ]
-        with open(os.path.join(self.processed_folder, self.training_file),
-                  'wb') as f:
+        with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
             torch.save(training_set, f)
-        with open(os.path.join(self.processed_folder, self.test_file),
-                  'wb') as f:
+        with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
             torch.save(test_set, f)

         if self.verbose:
-            print('Done!')
+            print("Done!")
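
Assuming Tecator is importable from prototorch.datasets, typical usage would look roughly like this; the root path and batch size are illustrative:

    import torch
    from prototorch.datasets import Tecator

    train_data = Tecator(root="./artifacts", train=True)  # triggers _download() on first use
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
    x, y = next(iter(train_loader))  # spectra batch and fat-content labels (0 - low_fat, 1 - high_fat)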

View File

@@ -1,15 +1,24 @@
 """ProtoTorch distance functions."""

 import torch
-from prototorch.functions.helper import equal_int_shape, _int_and_mixed_shape, _check_shapes
+from prototorch.functions.helper import (
+    equal_int_shape,
+    _int_and_mixed_shape,
+    _check_shapes,
+)
 import numpy as np


 def squared_euclidean_distance(x, y):
-    """Compute the squared Euclidean distance between :math:`x` and :math:`y`.
+    r"""Compute the squared Euclidean distance between :math:`\bm x` and :math:`\bm y`.

-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
+    Compute :math:`{\langle \bm x - \bm y \rangle}_2`
+
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+
+    **Alias:**
+    ``prototorch.functions.distances.sed``
     """
     expanded_x = x.unsqueeze(dim=1)
     batchwise_difference = y - expanded_x

@@ -19,10 +28,15 @@ def squared_euclidean_distance(x, y):

 def euclidean_distance(x, y):
-    """Compute the Euclidean distance between :math:`x` and :math:`y`.
+    r"""Compute the Euclidean distance between :math:`x` and :math:`y`.

-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
+    Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
+
+    :param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
+    :param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
+
+    :returns: Distance Tensor of shape :math:`X \times Y`
+    :rtype: `torch.tensor`
     """
     distances_raised = squared_euclidean_distance(x, y)
     distances = torch.sqrt(distances_raised)
@@ -30,10 +44,17 @@ def euclidean_distance(x, y):

 def lpnorm_distance(x, y, p):
-    r"""Compute :math:`{\langle x, y \rangle}_p`.
+    r"""
+    Calculates the lp-norm between :math:`\bm x` and :math:`\bm y`.
+    Also known as Minkowski distance.

-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
+    Compute :math:`{\| \bm x - \bm y \|}_p`.
+
+    Calls ``torch.cdist``
+
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+    :param p: p parameter of the lp norm
     """
     distances = torch.cdist(x, y, p=p)
     return distances

@@ -42,11 +63,11 @@ def lpnorm_distance(x, y, p):

 def omega_distance(x, y, omega):
     r"""Omega distance.

-    Compute :math:`{\langle \Omega x, \Omega y \rangle}_p`
+    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`

-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
-    Expected dimension of omega is 2.
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+    :param `torch.tensor` omega: Two dimensional matrix
     """
     projected_x = x @ omega
     projected_y = y @ omega
@@ -57,48 +78,55 @@ def omega_distance(x, y, omega):

 def lomega_distance(x, y, omegas):
     r"""Localized Omega distance.

-    Compute :math:`{\langle \Omega_k x, \Omega_k y_k \rangle}_p`
+    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`

-    Expected dimension of x is 2.
-    Expected dimension of y is 2.
-    Expected dimension of omegas is 3.
+    :param `torch.tensor` x: Two dimensional vector
+    :param `torch.tensor` y: Two dimensional vector
+    :param `torch.tensor` omegas: Three dimensional matrix
     """
     projected_x = x @ omegas
     projected_y = torch.diagonal(y @ omegas).T
     expanded_y = torch.unsqueeze(projected_y, dim=1)
     batchwise_difference = expanded_y - projected_x
-    differences_squared = batchwise_difference**2
+    differences_squared = batchwise_difference ** 2
     distances = torch.sum(differences_squared, dim=2)
     distances = distances.permute(1, 0)
     return distances


 def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
-    r""" Computes an euclidean distanes matrix given two distinct vectors.
+    r"""Computes an euclidean distances matrix given two distinct vectors.
     last dimension must be the vector dimension!
     compute the distance via the identity of the dot product. This avoids the memory overhead due to the subtraction!
-    x.shape = (number_of_x_vectors, vector_dim)
-    y.shape = (number_of_y_vectors, vector_dim)
+
+    - ``x.shape = (number_of_x_vectors, vector_dim)``
+    - ``y.shape = (number_of_y_vectors, vector_dim)``
+
     output: matrix of distances (number_of_x_vectors, number_of_y_vectors)
     """
     for tensor in [x, y]:
         if tensor.ndim != 2:
             raise ValueError(
-                'The tensor dimension must be two. You provide: tensor.ndim=' +
-                str(tensor.ndim) + '.')
+                "The tensor dimension must be two. You provide: tensor.ndim="
+                + str(tensor.ndim)
+                + "."
+            )
     if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
         raise ValueError(
-            'The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]='
-            + str(tuple(x.shape)[1]) + ' and tuple(y.shape)(y)[1]=' +
-            str(tuple(y.shape)[1]) + '.')
+            "The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
+            + str(tuple(x.shape)[1])
+            + " and tuple(y.shape)(y)[1]="
+            + str(tuple(y.shape)[1])
+            + "."
+        )

     y = torch.transpose(y)
-    diss = torch.sum(x**2, axis=1,
-                     keepdims=True) - 2 * torch.dot(x, y) + torch.sum(
-                         y**2, axis=0, keepdims=True)
+    diss = (
+        torch.sum(x ** 2, axis=1, keepdims=True)
+        - 2 * torch.dot(x, y)
+        + torch.sum(y ** 2, axis=0, keepdims=True)
+    )

     if not squared:
         if epsilon == 0:
@@ -110,13 +138,19 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):

 def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
-    r""" Tangent distances based on the tensorflow implementation of Sascha Saralajews
-    For more info about Tangen distances see DOI:10.1109/IJCNN.2016.7727534.
+    r"""Tangent distances based on the tensorflow implementation of Sascha Saralajews
+
+    For more info about Tangen distances see
+    DOI:10.1109/IJCNN.2016.7727534.
+
     The subspaces is always assumed as transposed and must be orthogonal!
     For local non sparse signals subspaces must be provided!
-    shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
-    shape(protos): proto_number x dim1 x dim2 x ... x dimN
-    shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
+
+    - shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
+    - shape(protos): proto_number x dim1 x dim2 x ... x dimN
+    - shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
+
     subspace should be orthogonalized
     Pytorch implementation of Sascha Saralajew's tensorflow code.
     Translation by Christoph Raab
@@ -139,18 +173,19 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
         if subspaces.ndim == 2:
             # clean solution without map if the matrix_scope is global
             projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
-                subspaces, torch.transpose(subspaces))
+                subspaces, torch.transpose(subspaces)
+            )

             projected_signals = torch.dot(signals, projectors)
             projected_protos = torch.dot(protos, projectors)

-            diss = euclidean_distance_matrix(projected_signals,
-                                             projected_protos,
-                                             squared=squared,
-                                             epsilon=epsilon)
+            diss = euclidean_distance_matrix(
+                projected_signals, projected_protos, squared=squared, epsilon=epsilon
+            )

             diss = torch.reshape(
-                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
+                diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
+            )

             return torch.permute(diss, [0, 2, 1])

@@ -158,18 +193,21 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
             # no solution without map possible --> memory efficient but slow!
             projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
-                subspaces,
-                subspaces)  #K.batch_dot(subspaces, subspaces, [2, 2])
+                subspaces, subspaces
+            )  # K.batch_dot(subspaces, subspaces, [2, 2])

-            projected_protos = (protos @ subspaces
-                                ).T  #K.batch_dot(projectors, protos, [1, 1]))
+            projected_protos = (
+                protos @ subspaces
+            ).T  # K.batch_dot(projectors, protos, [1, 1]))

             def projected_norm(projector):
-                return torch.sum(torch.dot(signals, projector)**2, axis=1)
+                return torch.sum(torch.dot(signals, projector) ** 2, axis=1)

-            diss = torch.transpose(map(projected_norm, projectors)) \
-                - 2 * torch.dot(signals, projected_protos) \
-                + torch.sum(projected_protos**2, axis=0, keepdims=True)
+            diss = (
+                torch.transpose(map(projected_norm, projectors))
+                - 2 * torch.dot(signals, projected_protos)
+                + torch.sum(projected_protos ** 2, axis=0, keepdims=True)
+            )

             if not squared:
                 if epsilon == 0:
@@ -178,7 +216,8 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
                 diss = torch.sqrt(torch.max(diss, epsilon))

             diss = torch.reshape(
-                diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
+                diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
+            )

             return torch.permute(diss, [0, 2, 1])

@@ -189,17 +228,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
         # global tangent space
         if subspaces.ndim == 2:
-            #Scope Projectors
+            # Scope Projectors
             projectors = subspaces  #

-            #Scope: Tangentspace Projections
+            # Scope: Tangentspace Projections
             diff = torch.reshape(
-                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
+            )
             projected_diff = diff @ projectors
             projected_diff = torch.reshape(
                 projected_diff,
-                (signal_shape[0], signal_shape[2], signal_shape[1]) +
-                signal_shape[3:])
+                (signal_shape[0], signal_shape[2], signal_shape[1]) + signal_shape[3:],
+            )

             diss = torch.norm(projected_diff, 2, dim=-1)
             return diss.permute([0, 2, 1])

@@ -211,13 +251,14 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
             # Scope: Tangentspace Projections
             diff = torch.reshape(
-                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
+                diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
+            )
             diff = diff.permute([1, 0, 2])
             projected_diff = torch.bmm(diff, projectors)
             projected_diff = torch.reshape(
                 projected_diff,
-                (signal_shape[1], signal_shape[0], signal_shape[2]) +
-                signal_shape[3:])
+                (signal_shape[1], signal_shape[0], signal_shape[2]) + signal_shape[3:],
+            )

             diss = torch.norm(projected_diff, 2, dim=-1)
             return diss.permute([1, 0, 2]).squeeze(-1)
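
To make the shape conventions of the vectorial distances concrete, a small usage sketch (sizes invented); each call is documented above to return one distance per data/prototype pair:

    import torch
    from prototorch.functions.distances import (
        squared_euclidean_distance,
        euclidean_distance,
        lpnorm_distance,
        omega_distance,
    )

    x = torch.rand(32, 3)   # 32 data points with 3 features
    y = torch.rand(5, 3)    # 5 prototypes with 3 features
    omega = torch.eye(3)    # identity projection, for illustration only

    print(squared_euclidean_distance(x, y).shape)  # torch.Size([32, 5])
    print(euclidean_distance(x, y).shape)          # torch.Size([32, 5])
    print(lpnorm_distance(x, y, p=1).shape)        # torch.Size([32, 5])
    print(omega_distance(x, y, omega).shape)       # torch.Size([32, 5])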

View File

@@ -14,11 +14,11 @@ class _Prototypes(torch.nn.Module):

     def _validate_prototype_distribution(self):
         if 0 in self.prototype_distribution:
-            warnings.warn('Are you sure about the `0` in '
-                          '`prototype_distribution`?')
+            warnings.warn("Are you sure about the `0` in "
+                          "`prototype_distribution`?")

     def extra_repr(self):
-        return f'prototypes.shape: {tuple(self.prototypes.shape)}'
+        return f"prototypes.shape: {tuple(self.prototypes.shape)}"

     def forward(self):
         return self.prototypes, self.prototype_labels

@@ -31,7 +31,7 @@ class Prototypes1D(_Prototypes):
     """

     def __init__(self,
                  prototypes_per_class=1,
-                 prototype_initializer='ones',
+                 prototype_initializer="ones",
                  prototype_distribution=None,
                  data=None,
                  dtype=torch.float32,
@@ -44,25 +44,25 @@ class Prototypes1D(_Prototypes):
             prototype_distribution = prototype_distribution.tolist()

         if data is None:
-            if 'input_dim' not in kwargs:
-                raise NameError('`input_dim` required if '
-                                'no `data` is provided.')
+            if "input_dim" not in kwargs:
+                raise NameError("`input_dim` required if "
+                                "no `data` is provided.")
             if prototype_distribution:
                 kwargs_nclasses = sum(prototype_distribution)
             else:
-                if 'nclasses' not in kwargs:
-                    raise NameError('`prototype_distribution` required if '
-                                    'both `data` and `nclasses` are not '
-                                    'provided.')
-                kwargs_nclasses = kwargs.pop('nclasses')
-            input_dim = kwargs.pop('input_dim')
+                if "nclasses" not in kwargs:
+                    raise NameError("`prototype_distribution` required if "
+                                    "both `data` and `nclasses` are not "
+                                    "provided.")
+                kwargs_nclasses = kwargs.pop("nclasses")
+            input_dim = kwargs.pop("input_dim")
             if prototype_initializer in [
-                    'stratified_mean', 'stratified_random'
+                    "stratified_mean", "stratified_random"
             ]:
                 warnings.warn(
-                    f'`prototype_initializer`: `{prototype_initializer}` '
-                    'requires `data`, but `data` is not provided. '
-                    'Using randomly generated data instead.')
+                    f"`prototype_initializer`: `{prototype_initializer}` "
+                    "requires `data`, but `data` is not provided. "
+                    "Using randomly generated data instead.")
             x_train = torch.rand(kwargs_nclasses, input_dim)
             y_train = torch.arange(kwargs_nclasses)
             if one_hot_labels:
@@ -75,39 +75,39 @@ class Prototypes1D(_Prototypes):
         nclasses = torch.unique(y_train, dim=-1).shape[-1]

         if nclasses == 1:
-            warnings.warn('Are you sure about having one class only?')
+            warnings.warn("Are you sure about having one class only?")

         if x_train.ndim != 2:
-            raise ValueError('`data[0].ndim != 2`.')
+            raise ValueError("`data[0].ndim != 2`.")

         if y_train.ndim == 2:
             if y_train.shape[1] == 1 and one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `True` '
-                                 'but target labels are not one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `True` "
+                                 "but target labels are not one-hot-encoded.")
             if y_train.shape[1] != 1 and not one_hot_labels:
-                raise ValueError('`one_hot_labels` is set to `False` '
-                                 'but target labels in `data` '
-                                 'are one-hot-encoded.')
+                raise ValueError("`one_hot_labels` is set to `False` "
+                                 "but target labels in `data` "
+                                 "are one-hot-encoded.")

         if y_train.ndim == 1 and one_hot_labels:
-            raise ValueError('`one_hot_labels` is set to `True` '
-                             'but target labels are not one-hot-encoded.')
+            raise ValueError("`one_hot_labels` is set to `True` "
+                             "but target labels are not one-hot-encoded.")

         # Verify input dimension if `input_dim` is provided
-        if 'input_dim' in kwargs:
-            input_dim = kwargs.pop('input_dim')
+        if "input_dim" in kwargs:
+            input_dim = kwargs.pop("input_dim")
             if input_dim != x_train.shape[1]:
-                raise ValueError(f'Provided `input_dim`={input_dim} does '
-                                 'not match data dimension '
-                                 f'`data[0].shape[1]`={x_train.shape[1]}')
+                raise ValueError(f"Provided `input_dim`={input_dim} does "
+                                 "not match data dimension "
+                                 f"`data[0].shape[1]`={x_train.shape[1]}")

         # Verify the number of classes if `nclasses` is provided
-        if 'nclasses' in kwargs:
-            kwargs_nclasses = kwargs.pop('nclasses')
+        if "nclasses" in kwargs:
+            kwargs_nclasses = kwargs.pop("nclasses")
             if kwargs_nclasses != nclasses:
-                raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
-                                 'not match data labels '
-                                 '`torch.unique(data[1]).shape[0]`'
-                                 f'={nclasses}')
+                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
+                                 "not match data labels "
+                                 "`torch.unique(data[1]).shape[0]`"
+                                 f"={nclasses}")

         super().__init__(**kwargs)
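
A quick usage sketch of the keyword arguments touched in this diff (values invented; the import path prototorch.modules is assumed, since the class belongs to the modules subpackage):

    import torch
    from prototorch.modules import Prototypes1D  # assumed import path

    # Without data: input_dim and nclasses are required, as checked above.
    protos = Prototypes1D(input_dim=3, nclasses=2, prototypes_per_class=2,
                          prototype_initializer="ones")

    # With data: pass (x_train, y_train); labels may optionally be one-hot encoded.
    x_train = torch.rand(100, 3)
    y_train = torch.randint(0, 2, (100,))
    protos = Prototypes1D(data=(x_train, y_train),
                          prototype_initializer="stratified_mean")

    prototypes, prototype_labels = protos()  # forward() returns both tensors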

View File

@@ -1,5 +1,13 @@
-"""Install ProtoTorch."""
+"""
+ _____           _     _______             _
+|  __ \         | |   |__   __|           | |
+| |__) | __ ___ | |_ ___ | | ___  _ __ ___| |__
+|  ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
+| |   | | | (_) | || (_) | | (_) | | | (__| | | |
+|_|   |_|  \___/ \__\___/|_|\___/|_|  \___|_| |_|
+ProtoTorch Core Package
+"""

 from setuptools import setup
 from setuptools import find_packages
@@ -32,40 +40,43 @@ EXAMPLES = [
 TESTS = ["pytest"]
 ALL = DOCS + DATASETS + EXAMPLES + TESTS

-setup(name="prototorch",
-      version="0.2.0",
-      description="Highly extensible, GPU-supported "
-      "Learning Vector Quantization (LVQ) toolbox "
-      "built using PyTorch and its nn API.",
-      long_description=long_description,
-      long_description_content_type="text/markdown",
-      author="Jensun Ravichandran",
-      author_email="jjensun@gmail.com",
-      url=PROJECT_URL,
-      download_url=DOWNLOAD_URL,
-      license="MIT",
-      install_requires=INSTALL_REQUIRES,
-      extras_require={
-          "docs": DOCS,
-          "datasets": DATASETS,
-          "examples": EXAMPLES,
-          "tests": TESTS,
-          "all": ALL,
-      },
-      classifiers=[
-          "Development Status :: 2 - Pre-Alpha",
-          "Environment :: Console",
-          "Intended Audience :: Developers",
-          "Intended Audience :: Education",
-          "Intended Audience :: Science/Research",
-          "License :: OSI Approved :: MIT License",
-          "Natural Language :: English",
-          "Programming Language :: Python :: 3.6",
-          "Programming Language :: Python :: 3.7",
-          "Programming Language :: Python :: 3.8",
-          "Operating System :: OS Independent",
-          "Topic :: Scientific/Engineering :: Artificial Intelligence",
-          "Topic :: Software Development :: Libraries",
-          "Topic :: Software Development :: Libraries :: Python Modules",
-      ],
-      packages=find_packages())
+setup(
+    name="prototorch",
+    version="0.3.0-dev0",
+    description="Highly extensible, GPU-supported "
+    "Learning Vector Quantization (LVQ) toolbox "
+    "built using PyTorch and its nn API.",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    author="Jensun Ravichandran",
+    author_email="jjensun@gmail.com",
+    url=PROJECT_URL,
+    download_url=DOWNLOAD_URL,
+    license="MIT",
+    install_requires=INSTALL_REQUIRES,
+    extras_require={
+        "docs": DOCS,
+        "datasets": DATASETS,
+        "examples": EXAMPLES,
+        "tests": TESTS,
+        "all": ALL,
+    },
+    classifiers=[
+        "Development Status :: 2 - Pre-Alpha",
+        "Environment :: Console",
+        "Intended Audience :: Developers",
+        "Intended Audience :: Education",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: MIT License",
+        "Natural Language :: English",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Operating System :: OS Independent",
+        "Topic :: Scientific/Engineering :: Artificial Intelligence",
+        "Topic :: Software Development :: Libraries",
+        "Topic :: Software Development :: Libraries :: Python Modules",
+    ],
+    packages=find_packages(),
+    zip_safe=False,
+)