31 Commits

Author SHA1 Message Date
Christoph
94fe4435a8 Bump version: 0.4.4 → 0.4.5 2021-05-27 09:58:25 +02:00
Alexander Engelsberger
c204bc8e1f integrate reviews from ChristophRaab:master 2021-05-27 09:43:02 +02:00
Alexander Engelsberger
00615ae837 refactored gtlvq from ChristophRaab:master 2021-05-27 09:40:42 +02:00
Jensun Ravichandran
9f5f0d12dd [BUGFIX] Parse dictionary distribution appropirately 2021-05-25 20:52:39 +02:00
Jensun Ravichandran
8a291f7bfb Overload distribution argument in component initializers
The component initializers behave differently based on the type of the
`distribution` argument. If it is a Python
[list](https://docs.python.org/3/tutorial/datastructures.html), it is assumed
that there are as many entries in this list as there are classes, and the number
at each location of this list describes the number of prototypes to be used for
that particular class. So, `[1, 1, 1]` implies that we have three classes with
one prototype per class. If it is a Python
[tuple](https://docs.python.org/3/tutorial/datastructures.html), it a shorthand
of `(num_classes, prototypes_per_class)` is assumed. If it is a Python
[dictionary](https://docs.python.org/3/tutorial/datastructures.html), the
key-value pairs describe the class label and the number of prototypes for that
class respectively. So, `{0: 2, 1: 2, 2: 2}` implies that we have three classes
with labels `{1, 2, 3}`, each equipped with two prototypes.
2021-05-25 20:05:29 +02:00
Alexander Engelsberger
21e3e3b82d Cache pip in CI 2021-05-25 16:43:48 +02:00
Alexander Engelsberger
a6bd6e130a Add subpackages into prototorch namespace. 2021-05-25 16:40:53 +02:00
Alexander Engelsberger
fcdfa52892 Ignore artiifacts folder 2021-05-25 16:40:34 +02:00
Alexander Engelsberger
73e6fe384e Use 'num_' in all variable names 2021-05-25 15:57:05 +02:00
Alexander Engelsberger
aff7a385a3 Use dict for distribution
This change allows the use of LightningCLI.
2021-05-21 17:10:02 +02:00
Jensun Ravichandran
1e23ba05fa Add test_components 2021-05-21 16:22:02 +02:00
Alexander Engelsberger
ee30d4da5b [BUGFIX] Initializers can handle Dataloaders now 2021-05-21 16:00:20 +02:00
Alexander Engelsberger
14508f0600 [DOC] Small improvements 2021-05-21 15:59:44 +02:00
Jensun Ravichandran
e3f8828da4 Accept dataloaders for component initialization 2021-05-21 11:59:57 +02:00
Jensun Ravichandran
30adbf705c Update dependencies 2021-05-20 11:44:53 +02:00
Jensun Ravichandran
ee42fd68b1 NumpyDataset now has data and targets properties 2021-05-18 19:38:40 +02:00
Jensun Ravichandran
736d9a6349 Rename PositionAwareInitializer to DataAwareInitializer
Also, add the aliases `Zeros` and `Ones`.
2021-05-18 19:37:25 +02:00
Alexander Engelsberger
0055e15bc1 [DOC] Fix iris data dimension 2021-05-18 18:57:03 +02:00
Alexander Engelsberger
b2e1df7308 Improve dataset documentation. 2021-05-18 18:54:43 +02:00
Jensun Ravichandran
b935e9caf3 Update _get_dp_dm 2021-05-18 13:09:11 +02:00
Jensun Ravichandran
503ef0e05f Cleanup components 2021-05-17 16:58:57 +02:00
Jensun Ravichandran
dc6248413c Apply transformations in component initializers 2021-05-17 16:58:22 +02:00
Jensun Ravichandran
e73b70ceb7 Minor aesthetic change 2021-05-17 16:57:41 +02:00
Jensun Ravichandran
639198e774 Update Iris dataset 2021-05-17 16:57:13 +02:00
Alexander Engelsberger
768d969f89 Device agnostic initialization of components. 2021-05-13 15:21:04 +02:00
Alexander Engelsberger
aec422c277 Remove copy paste error from documentation. 2021-05-13 11:56:38 +02:00
Jensun Ravichandran
6c14170de6 [BUGFIX] Fix typo 2021-05-12 16:31:22 +02:00
Jensun Ravichandran
36a330aa66 Update component initializers 2021-05-12 16:28:55 +02:00
Jensun Ravichandran
acd4ac6a86 Flatten tensors before computing distances 2021-05-12 16:28:34 +02:00
Jensun Ravichandran
abe64cfe8f Merge pull request #3 from dmoebius-dm/dev
Removed wrong parameter.
2021-05-12 16:23:27 +02:00
Danny
caae95d01d Removed wrong parameter. 2021-05-12 16:00:01 +02:00
26 changed files with 372 additions and 250 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.4
current_version = 0.4.5
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

3
.gitignore vendored
View File

@@ -154,4 +154,5 @@ scratch*
# End of https://www.gitignore.io/api/visualstudiocode
.vscode/
reports
reports
artifacts

View File

@@ -4,7 +4,9 @@ language: python
python: 3.8
cache:
directories:
- "$HOME/.cache/pip"
- "./tests/artifacts"
- "$HOME/datasets"
install:
- pip install .[all] --progress-bar off

View File

@@ -1,13 +1,24 @@
.. ProtoFlow API Reference
.. ProtoTorch API Reference
ProtoFlow API Reference
ProtoTorch API Reference
======================================
Datasets
--------------------------------------
Common Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: prototorch.datasets
:members:
:undoc-members:
Abstract Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Abstract Datasets are used to build your own datasets.
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
:members:
Functions
--------------------------------------

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.4.4"
release = "0.4.5"
# -- General configuration ---------------------------------------------------
@@ -46,6 +46,7 @@ extensions = [
"sphinx.ext.viewcode",
"sphinx_rtd_theme",
"sphinxcontrib.katex",
'sphinx_autodoc_typehints',
]
# katex_prerender = True
@@ -179,6 +180,9 @@ texinfo_documents = [
intersphinx_mapping = {
"python": ("https://docs.python.org/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"torch": ('https://pytorch.org/docs/stable/', None),
"pytorch_lightning":
("https://pytorch-lightning.readthedocs.io/en/stable/", None),
}
# -- Options for Epub output ----------------------------------------------

View File

@@ -3,14 +3,13 @@
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data
scaler = StandardScaler()
@@ -28,7 +27,7 @@ class Model(torch.nn.Module):
self.proto_layer = Prototypes1D(
input_dim=2,
prototypes_per_class=3,
nclasses=3,
num_classes=3,
prototype_initializer="stratified_random",
data=[x_train, y_train],
)

View File

@@ -2,13 +2,12 @@
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles
from torch.utils.data import DataLoader
# Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True)
@@ -23,7 +22,7 @@ class Model(torch.nn.Module):
self.p1 = Prototypes1D(
input_dim=100,
prototypes_per_class=2,
nclasses=2,
num_classes=2,
prototype_initializer="stratified_random",
data=[x, y],
)

View File

@@ -12,14 +12,13 @@ import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ
from torchvision import transforms
# Parameters and options
n_epochs = 50
num_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
@@ -141,7 +140,7 @@ optimizer = torch.optim.Adam(
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
# Training loop
for epoch in range(n_epochs):
for epoch in range(num_epochs):
for batch_idx, (x_train, y_train) in enumerate(train_loader):
model.train()
x_train, y_train = x_train.to(device), y_train.to(device)
@@ -161,7 +160,7 @@ for epoch in range(n_epochs):
if batch_idx % log_interval == 0:
acc = calculate_prototype_accuracy(distances, y_train, plabels)
print(
f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
Train Acc: {acc.item():02.02f}")
# Test

View File

@@ -1,21 +1,24 @@
"""ProtoTorch package."""
import pkgutil
import pkg_resources
from . import components, datasets, functions, modules, utils
from .datasets import *
# Core Setup
__version__ = "0.4.4"
__version__ = "0.4.5"
__all_core__ = [
"datasets",
"functions",
"modules",
"components",
"utils",
]
from .datasets import *
# Plugin Loader
import pkgutil
import pkg_resources
__path__ = pkgutil.extend_path(__path__, __name__)

View File

@@ -1,36 +1,37 @@
"""ProtoTorch components modules."""
import warnings
from typing import Tuple
import torch
from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer,
CustomLabelsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer,
ZeroReasoningsInitializer)
from prototorch.functions.initializers import get_initializer
from torch.nn.parameter import Parameter
class Components(torch.nn.Module):
"""Components is a set of learnable Tensors."""
def __init__(self,
number_of_components=None,
num_components=None,
initializer=None,
*,
initialized_components=None,
dtype=torch.float32):
initialized_components=None):
super().__init__()
self.num_components = num_components
# Ignore all initialization settings if initialized_components is given.
if initialized_components is not None:
self._components = Parameter(initialized_components)
if number_of_components is not None or initializer is not None:
self.register_parameter("_components",
Parameter(initialized_components))
if num_components is not None or initializer is not None:
wmsg = "Arguments ignored while initializing Components"
warnings.warn(wmsg)
else:
self._initialize_components(number_of_components, initializer)
self._initialize_components(initializer)
def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer):
@@ -39,15 +40,15 @@ class Components(torch.nn.Module):
f"You have provided: {initializer=} instead."
raise TypeError(emsg)
def _initialize_components(self, number_of_components, initializer):
def _initialize_components(self, initializer):
self._precheck_initializer(initializer)
self._components = Parameter(
initializer.generate(number_of_components))
_components = initializer.generate(self.num_components)
self.register_parameter("_components", Parameter(_components))
@property
def components(self):
"""Tensor containing the component tensors."""
return self._components.detach().cpu()
return self._components.detach()
def forward(self):
return self._components
@@ -71,33 +72,40 @@ class LabeledComponents(Components):
super().__init__(initialized_components=components)
self._labels = component_labels
else:
self._initialize_labels(distribution)
super().__init__(number_of_components=len(self._labels),
initializer=initializer)
_labels = self._initialize_labels(distribution)
super().__init__(len(_labels), initializer=initializer)
self.register_buffer("_labels", _labels)
def _initialize_components(self, number_of_components, initializer):
def _initialize_components(self, initializer):
if isinstance(initializer, ClassAwareInitializer):
self._precheck_initializer(initializer)
self._components = Parameter(
initializer.generate(number_of_components, self.distribution))
_components = initializer.generate(self.num_components,
self.distribution)
self.register_parameter("_components", Parameter(_components))
else:
super()._initialize_components(self, number_of_components,
initializer)
super()._initialize_components(initializer)
def _initialize_labels(self, distribution):
if type(distribution) == tuple:
if type(distribution) == dict:
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer(
distribution["num_classes"],
distribution["prototypes_per_class"])
else:
labels = CustomLabelsInitializer(distribution)
elif type(distribution) == tuple:
num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif type(distribution) == list:
labels = UnequalLabelsInitializer(distribution)
self.distribution = labels.distribution
self._labels = labels.generate()
return labels.generate()
@property
def component_labels(self):
"""Tensor containing the component tensors."""
return self._labels.detach().cpu()
return self._labels.detach()
def forward(self):
return super().forward(), self._labels
@@ -124,20 +132,21 @@ class ReasoningComponents(Components):
*,
initialized_components=None):
if initialized_components is not None:
super().__init__(initialized_components=initialized_components[0])
self._reasonings = initialized_components[1]
components, reasonings = initialized_components
super().__init__(initialized_components=components)
self.register_parameter("_reasonings", reasonings)
else:
self._initialize_reasonings(reasonings)
super().__init__(number_of_components=len(self._reasonings),
initializer=initializer)
super().__init__(len(self._reasonings), initializer=initializer)
def _initialize_reasonings(self, reasonings):
if type(reasonings) == tuple:
num_classes, number_of_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes,
number_of_components)
num_classes, num_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes, num_components)
self._reasonings = reasonings.generate()
_reasonings = reasonings.generate()
self.register_parameter("_reasonings", _reasonings)
@property
def reasonings(self):
@@ -146,7 +155,7 @@ class ReasoningComponents(Components):
Dimension NxCx2
"""
return self._reasonings.detach().cpu()
return self._reasonings.detach()
def forward(self):
return super().forward(), self._reasonings

View File

@@ -7,12 +7,18 @@ import torch
from torch.utils.data import DataLoader, Dataset
def parse_init_arg(arg):
if isinstance(arg, Dataset):
data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
# data = data.view(len(arg), -1) # flatten
def parse_data_arg(data_arg):
if isinstance(data_arg, Dataset):
data_arg = DataLoader(data_arg, batch_size=len(data_arg))
if isinstance(data_arg, DataLoader):
data = torch.tensor([])
labels = torch.tensor([])
for x, y in data_arg:
data = torch.cat([data, x])
labels = torch.cat([labels, y])
else:
data, labels = arg
data, labels = data_arg
if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}."
warnings.warn(wmsg)
@@ -63,19 +69,19 @@ class UniformInitializer(DimensionAwareInitializer):
return torch.ones(gen_dims).uniform_(self.min, self.max)
class PositionAwareInitializer(ComponentsInitializer):
def __init__(self, positions):
class DataAwareInitializer(ComponentsInitializer):
def __init__(self, data):
super().__init__()
self.data = positions
self.data = data
class SelectionInitializer(PositionAwareInitializer):
class SelectionInitializer(DataAwareInitializer):
def generate(self, length):
indices = torch.LongTensor(length).random_(0, len(self.data))
return self.data[indices]
class MeanInitializer(PositionAwareInitializer):
class MeanInitializer(DataAwareInitializer):
def generate(self, length):
mean = torch.mean(self.data, dim=0)
repeat_dim = [length] + [1] * len(mean.shape)
@@ -83,12 +89,14 @@ class MeanInitializer(PositionAwareInitializer):
class ClassAwareInitializer(ComponentsInitializer):
def __init__(self, arg):
def __init__(self, data, transform=torch.nn.Identity()):
super().__init__()
data, labels = parse_init_arg(arg)
data, labels = parse_data_arg(data)
self.data = data
self.labels = labels
self.transform = transform
self.clabels = torch.unique(self.labels)
self.num_classes = len(self.clabels)
@@ -96,15 +104,24 @@ class ClassAwareInitializer(ComponentsInitializer):
if not dist:
per_class = length // self.num_classes
dist = self.num_classes * [per_class]
if type(dist) == dict:
dist = dist.values()
samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist)
]
return torch.vstack(samples_list)
out = torch.vstack(samples_list)
with torch.no_grad():
out = self.transform(out)
return out
def __del__(self):
del self.data
del self.labels
class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, arg):
super().__init__(arg)
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
self.initializers = []
for clabel in self.clabels:
@@ -118,8 +135,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, arg, *, noise=None):
super().__init__(arg)
def __init__(self, data, noise=None, **kwargs):
super().__init__(data, **kwargs)
self.noise = noise
self.initializers = []
@@ -128,7 +145,10 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
class_initializer = SelectionInitializer(class_data)
self.initializers.append(class_initializer)
def add_noise(self, x):
def add_noise_v1(self, x):
return x + self.noise
def add_noise_v2(self, x):
"""Shifts some dimensions of the data randomly."""
n1 = torch.rand_like(x)
n2 = torch.rand_like(x)
@@ -138,8 +158,7 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
def generate(self, length, dist=[]):
samples = self._get_samples_from_initializer(length, dist)
if self.noise is not None:
# samples = self.add_noise(samples)
samples = samples + self.noise
samples = self.add_noise_v1(samples)
return samples
@@ -157,10 +176,13 @@ class UnequalLabelsInitializer(LabelsInitializer):
def distribution(self):
return self.dist
def generate(self):
clabels = range(len(self.dist))
labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
return torch.tensor(labels)
def generate(self, clabels=None, dist=None):
if not clabels:
clabels = range(len(self.dist))
if not dist:
dist = self.dist
labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
return torch.LongTensor(labels)
class EqualLabelsInitializer(LabelsInitializer):
@@ -176,6 +198,13 @@ class EqualLabelsInitializer(LabelsInitializer):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
class CustomLabelsInitializer(UnequalLabelsInitializer):
def generate(self):
clabels = list(self.dist.keys())
dist = list(self.dist.values())
return super().generate(clabels, dist)
# Reasonings
class ReasoningsInitializer:
def generate(self, length):
@@ -195,3 +224,5 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
SMI = StratifiedMeanInitializer
Random = RandomInitializer = UniformInitializer
Zeros = ZerosInitializer
Ones = OnesInitializer

View File

@@ -4,3 +4,5 @@ from .abstract import NumpyDataset
from .iris import Iris
from .spiral import Spiral
from .tecator import Tecator
__all__ = ['Iris', 'Spiral', 'Tecator']

View File

@@ -14,8 +14,10 @@ import torch
class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, *arrays):
tensors = [torch.Tensor(arr) for arr in arrays]
def __init__(self, data, targets):
self.data = data
self.targets = targets
tensors = [torch.Tensor(data), torch.Tensor(targets)]
super().__init__(*tensors)

View File

@@ -5,11 +5,36 @@ URL:
"""
from typing import Sequence
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris
class Iris(NumpyDataset):
def __init__(self):
"""
Iris Dataset by Ronald Fisher introduced in 1936.
The dataset contains four measurements from flowers of three species of iris.
.. list-table:: Iris
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 4
- 3
- 150
- 0
- 0
:param dims: select a subset of dimensions
"""
def __init__(self, dims: Sequence[int] = None):
x, y = load_iris(return_X_y=True)
if dims:
x = x[:, dims]
super().__init__(x, y)

View File

@@ -4,18 +4,22 @@ import numpy as np
import torch
def make_spiral(n_samples=500, noise=0.3):
def make_spiral(num_samples=500, noise=0.3):
"""Generates the Spiral Dataset.
For use in Prototorch use `prototorch.datasets.Spiral` instead.
"""
def get_samples(n, delta_t):
points = []
for i in range(n):
r = i / n_samples * 5
r = i / num_samples * 5
t = 1.75 * i / n * 2 * np.pi + delta_t
x = r * np.sin(t) + np.random.rand(1) * noise
y = r * np.cos(t) + np.random.rand(1) * noise
points.append([x, y])
return points
n = n_samples // 2
n = num_samples // 2
positive = get_samples(n=n, delta_t=0)
negative = get_samples(n=n, delta_t=np.pi)
x = np.concatenate(
@@ -27,7 +31,27 @@ def make_spiral(n_samples=500, noise=0.3):
class Spiral(torch.utils.data.TensorDataset):
"""Spiral dataset for binary classification."""
def __init__(self, n_samples=500, noise=0.3):
x, y = make_spiral(n_samples, noise)
"""Spiral dataset for binary classification.
This datasets consists of two spirals of two different classes.
.. list-table:: Spiral
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 2
- 2
- num_samples
- 0
- 0
:param num_samples: number of random samples
:param noise: noise added to the spirals
"""
def __init__(self, num_samples: int = 500, noise: float = 0.3):
x, y = make_spiral(num_samples, noise)
super().__init__(torch.Tensor(x), torch.LongTensor(y))

View File

@@ -40,15 +40,29 @@ import os
import numpy as np
import torch
from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
from torchvision.datasets.utils import download_file_from_google_drive
class Tecator(ProtoDataset):
"""
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
for classification.
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.
The dataset contains wavelength measurements of meat.
.. list-table:: Tecator
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 100
- 2
- 129
- 43
- 43
"""
_resources = [

View File

@@ -5,12 +5,12 @@ import torch
def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0]
if distances.size()[1] == nclasses:
num_classes = clabels.size()[0]
if distances.size()[1] == num_classes:
# skip if only one prototype per class
return distances
batch_size = distances.size()[0]
winning_distances = torch.zeros(nclasses, batch_size)
winning_distances = torch.zeros(num_classes, batch_size)
inf = torch.full_like(distances.T, fill_value=float("inf"))
# distances_to_wpluses = torch.where(matcher, distances, inf)
for i, cl in enumerate(clabels):
@@ -18,7 +18,7 @@ def stratified_min(distances, labels):
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
if labels.ndim == 2:
# if the labels are one-hot vectors
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
cdists = torch.where(matcher, distances.T, inf).T
winning_distances[i] = torch.min(cdists, dim=1,
keepdim=True).values.squeeze()

View File

@@ -2,9 +2,8 @@
import numpy as np
import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
equal_int_shape)
equal_int_shape, get_flat)
def squared_euclidean_distance(x, y):
@@ -12,12 +11,10 @@ def squared_euclidean_distance(x, y):
Compute :math:`{\langle \bm x - \bm y \rangle}_2`
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
**Alias:**
``prototorch.functions.distances.sed``
"""
x, y = get_flat(x, y)
expanded_x = x.unsqueeze(dim=1)
batchwise_difference = y - expanded_x
differences_raised = torch.pow(batchwise_difference, 2)
@@ -30,18 +27,17 @@ def euclidean_distance(x, y):
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
:param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
:param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
:returns: Distance Tensor of shape :math:`X \times Y`
:rtype: `torch.tensor`
"""
x, y = get_flat(x, y)
distances_raised = squared_euclidean_distance(x, y)
distances = torch.sqrt(distances_raised)
return distances
def euclidean_distance_v2(x, y):
x, y = get_flat(x, y)
diff = y - x.unsqueeze(1)
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
@@ -62,10 +58,9 @@ def lpnorm_distance(x, y, p):
Calls ``torch.cdist``
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param p: p parameter of the lp norm
"""
x, y = get_flat(x, y)
distances = torch.cdist(x, y, p=p)
return distances
@@ -75,10 +70,9 @@ def omega_distance(x, y, omega):
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param `torch.tensor` omega: Two dimensional matrix
"""
x, y = get_flat(x, y)
projected_x = x @ omega
projected_y = y @ omega
distances = squared_euclidean_distance(projected_x, projected_y)
@@ -90,10 +84,9 @@ def lomega_distance(x, y, omegas):
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param `torch.tensor` omegas: Three dimensional matrix
"""
x, y = get_flat(x, y)
projected_x = x @ omegas
projected_y = torch.diagonal(y @ omegas).T
expanded_y = torch.unsqueeze(projected_y, dim=1)

View File

@@ -1,6 +1,11 @@
import torch
def get_flat(*args):
rv = [x.view(x.size(0), -1) for x in args]
return rv
def calculate_prototype_accuracy(y_pred, y_true, plabels):
"""Computes the accuracy of a prototype based model.
via Winner-Takes-All rule.

View File

@@ -15,59 +15,59 @@ def register_initializer(function):
def labels_from(distribution, one_hot=True):
"""Takes a distribution tensor and returns a labels tensor."""
nclasses = distribution.shape[0]
llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
num_classes = distribution.shape[0]
llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
# labels = [l for cl in llist for l in cl] # flatten the list of lists
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
plabels = torch.tensor(flat_llist, requires_grad=False)
if one_hot:
return torch.eye(nclasses)[plabels]
return torch.eye(num_classes)[plabels]
return plabels
@register_initializer
def ones(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.ones(nprotos, *x_train.shape[1:])
num_protos = torch.sum(prototype_distribution)
protos = torch.ones(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.zeros(nprotos, *x_train.shape[1:])
num_protos = torch.sum(prototype_distribution)
protos = torch.zeros(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def rand(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.rand(nprotos, *x_train.shape[1:])
num_protos = torch.sum(prototype_distribution)
protos = torch.rand(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def randn(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
protos = torch.randn(nprotos, *x_train.shape[1:])
num_protos = torch.sum(prototype_distribution)
protos = torch.randn(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels
@register_initializer
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution)
num_protos = torch.sum(prototype_distribution)
pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim)
protos = torch.empty(num_protos, pdim)
plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot:
nclasses = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
num_classes = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
xl = x_train[matcher]
mean_xl = torch.mean(xl, dim=0)
protos[i] = mean_xl
@@ -81,15 +81,15 @@ def stratified_random(x_train,
prototype_distribution,
one_hot=True,
epsilon=1e-7):
nprotos = torch.sum(prototype_distribution)
num_protos = torch.sum(prototype_distribution)
pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim)
protos = torch.empty(num_protos, pdim)
plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot:
nclasses = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
num_classes = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
xl = x_train[matcher]
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
random_xl = xl[rand_index]

View File

@@ -8,12 +8,12 @@ def _get_matcher(targets, labels):
matcher = torch.eq(targets.unsqueeze(dim=1), labels)
if labels.ndim == 2:
# if the labels are one-hot vectors
nclasses = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
num_classes = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
return matcher
def _get_dp_dm(distances, targets, plabels):
def _get_dp_dm(distances, targets, plabels, with_indices=False):
"""Returns the d+ and d- values for a batch of distances."""
matcher = _get_matcher(targets, plabels)
not_matcher = torch.bitwise_not(matcher)
@@ -21,9 +21,11 @@ def _get_dp_dm(distances, targets, plabels):
inf = torch.full_like(distances, fill_value=float("inf"))
d_matching = torch.where(matcher, distances, inf)
d_unmatching = torch.where(not_matcher, distances, inf)
dp = torch.min(d_matching, dim=1, keepdim=True).values
dm = torch.min(d_unmatching, dim=1, keepdim=True).values
return dp, dm
dp = torch.min(d_matching, dim=-1, keepdim=True)
dm = torch.min(d_unmatching, dim=-1, keepdim=True)
if with_indices:
return dp, dm
return dp.values, dm.values
def glvq_loss(distances, target_labels, prototype_labels):
@@ -47,10 +49,11 @@ def lvq1_loss(distances, target_labels, prototype_labels):
def lvq21_loss(distances, target_labels, prototype_labels):
"""LVQ2.1 loss function with support for one-hot labels.
See Section 4 [Sado&Yamada]
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
"""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp - dm
return mu
return mu

View File

@@ -1,12 +1,8 @@
import torch
from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
from torch import nn
class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization
@@ -80,22 +76,11 @@ class GTLVQ(nn.Module):
super(GTLVQ, self).__init__()
self.num_protos = num_classes * prototypes_per_class
self.num_protos_class = prototypes_per_class
self.subspace_size = feature_dim if subspace_size is None else subspace_size
self.feature_dim = feature_dim
self.num_classes = num_classes
if subspace_data is None:
raise ValueError("Init Data must be specified!")
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == "local" or self.tpt == "local_proj":
self.init_local_subspace(subspace_data)
elif self.tpt == "global":
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
# Hypothesis-Margin-Classifier
self.cls = Prototypes1D(
input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
@@ -104,21 +89,22 @@ class GTLVQ(nn.Module):
data=prototype_data,
)
def forward(self, x):
# Tangent Projection
if self.tpt == "local_proj":
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis, proj_x = self.local_tangent_projection(x_conform)
if subspace_data is None:
raise ValueError("Init Data must be specified!")
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
self.feature_dim)
return proj_x, dis
elif self.tpt == "local":
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis = tangent_distance(x_conform, self.cls.prototypes,
self.subspaces)
self.tpt = tangent_projection_type
with torch.no_grad():
if self.tpt == "local":
self.init_local_subspace(subspace_data, subspace_size,
self.num_protos)
elif self.tpt == "global":
self.init_gobal_subspace(subspace_data, subspace_size)
else:
self.subspaces = None
def forward(self, x):
if self.tpt == "local":
dis = self.local_tangent_distances(x)
elif self.tpt == "gloabl":
dis = self.global_tangent_distances(x)
else:
@@ -131,16 +117,14 @@ class GTLVQ(nn.Module):
_, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = subspace[:, :num_subspaces]
self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True))
self.subspaces = nn.Parameter(subspaces, requires_grad=True)
def init_local_subspace(self, data):
_, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0)
self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True))
def init_local_subspace(self, data,num_subspaces,num_protos):
data = data - torch.mean(data,dim=0)
_,_,v = torch.svd(data,some=False)
v = v[:,:num_subspaces]
subspaces = v.unsqueeze(0).repeat_interleave(num_protos,0)
self.subspaces = nn.Parameter(subspaces,requires_grad=True)
def global_tangent_distances(self, x):
# Tangent Projection
@@ -151,33 +135,21 @@ class GTLVQ(nn.Module):
# Euclidean Distance
return euclidean_distance_matrix(x, projected_prototypes)
def local_tangent_projection(self, signals):
# Note: subspaces is always assumed as transposed and must be orthogonal!
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
# subspace should be orthogonalized
# Origin Source Code
# Origin Author:
protos = self.cls.prototypes
subspaces = self.subspaces
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
_, proto_int_shape = _int_and_mixed_shape(protos)
def local_tangent_distances(self, x):
# check if the shapes are correct
_check_shapes(signal_int_shape, proto_int_shape)
# Tangent Data Projections
projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
data = signals.squeeze(2).permute([1, 0, 2])
projected_data = torch.bmm(data, subspaces)
projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
diff = projected_data - projected_protos
projected_diff = torch.reshape(
diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
signal_shape[3:])
diss = torch.norm(projected_diff, 2, dim=-1)
return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
# Tangent Distance
x = x.unsqueeze(1).expand(x.size(0), self.cls.prototypes.size(0),
x.size(-1))
protos = self.cls.prototypes.unsqueeze(0).expand(
x.size(0), self.cls.prototypes.size(0), x.size(-1))
projectors = torch.eye(
self.subspaces.shape[-2], device=x.device) - torch.bmm(
self.subspaces, self.subspaces.permute([0, 2, 1]))
diff = (x - protos)
diff = diff.permute([1, 0, 2])
diff = torch.bmm(diff, projectors)
diff = torch.norm(diff,2,dim=-1).T
return diff
def get_parameters(self):
return {

View File

@@ -3,7 +3,6 @@
import warnings
import torch
from prototorch.functions.initializers import get_initializer
@@ -53,13 +52,13 @@ class Prototypes1D(_Prototypes):
raise NameError("`input_dim` required if "
"no `data` is provided.")
if prototype_distribution:
kwargs_nclasses = sum(prototype_distribution)
kwargs_num_classes = sum(prototype_distribution)
else:
if "nclasses" not in kwargs:
if "num_classes" not in kwargs:
raise NameError("`prototype_distribution` required if "
"both `data` and `nclasses` are not "
"both `data` and `num_classes` are not "
"provided.")
kwargs_nclasses = kwargs.pop("nclasses")
kwargs_num_classes = kwargs.pop("num_classes")
input_dim = kwargs.pop("input_dim")
if prototype_initializer in [
"stratified_mean", "stratified_random"
@@ -68,18 +67,18 @@ class Prototypes1D(_Prototypes):
f"`prototype_initializer`: `{prototype_initializer}` "
"requires `data`, but `data` is not provided. "
"Using randomly generated data instead.")
x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(kwargs_nclasses)
x_train = torch.rand(kwargs_num_classes, input_dim)
y_train = torch.arange(kwargs_num_classes)
if one_hot_labels:
y_train = torch.eye(kwargs_nclasses)[y_train]
y_train = torch.eye(kwargs_num_classes)[y_train]
data = [x_train, y_train]
x_train, y_train = data
x_train = torch.as_tensor(x_train).type(dtype)
y_train = torch.as_tensor(y_train).type(torch.int)
nclasses = torch.unique(y_train, dim=-1).shape[-1]
num_classes = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1:
if num_classes == 1:
warnings.warn("Are you sure about having one class only?")
if x_train.ndim != 2:
@@ -105,19 +104,20 @@ class Prototypes1D(_Prototypes):
"not match data dimension "
f"`data[0].shape[1]`={x_train.shape[1]}")
# Verify the number of classes if `nclasses` is provided
if "nclasses" in kwargs:
kwargs_nclasses = kwargs.pop("nclasses")
if kwargs_nclasses != nclasses:
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
"not match data labels "
"`torch.unique(data[1]).shape[0]`"
f"={nclasses}")
# Verify the number of classes if `num_classes` is provided
if "num_classes" in kwargs:
kwargs_num_classes = kwargs.pop("num_classes")
if kwargs_num_classes != num_classes:
raise ValueError(
f"Provided `num_classes={kwargs_num_classes}` does "
"not match data labels "
"`torch.unique(data[1]).shape[0]`"
f"={num_classes}")
super().__init__(**kwargs)
if not prototype_distribution:
prototype_distribution = [prototypes_per_class] * nclasses
prototype_distribution = [prototypes_per_class] * num_classes
with torch.no_grad():
self.prototype_distribution = torch.tensor(prototype_distribution)

View File

@@ -20,10 +20,10 @@ INSTALL_REQUIRES = [
"torch>=1.3.1",
"torchvision>=0.5.0",
"numpy>=1.9.1",
"sklearn",
]
DATASETS = [
"requests",
"sklearn",
"tqdm",
]
DEV = ["bumpversion"]
@@ -32,9 +32,9 @@ DOCS = [
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinx-autodoc-typehints",
]
EXAMPLES = [
"sklearn",
"matplotlib",
"torchinfo",
]
@@ -43,7 +43,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup(
name="prototorch",
version="0.4.4",
version="0.4.5",
description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.",

25
tests/test_components.py Normal file
View File

@@ -0,0 +1,25 @@
"""ProtoTorch components test suite."""
import prototorch as pt
import torch
def test_labcomps_zeros_init():
protos = torch.zeros(3, 2)
c = pt.components.LabeledComponents(
distribution=[1, 1, 1],
initializer=pt.components.Zeros(2),
)
assert (c.components == protos).any() == True
def test_labcomps_warmstart():
protos = torch.randn(3, 2)
plabels = torch.tensor([1, 2, 3])
c = pt.components.LabeledComponents(
distribution=[1, 1, 1],
initializer=None,
initialized_components=[protos, plabels],
)
assert (c.components == protos).any() == True
assert (c.component_labels == plabels).any() == True

View File

@@ -4,7 +4,6 @@ import unittest
import numpy as np
import torch
from prototorch.modules import losses, prototypes
@@ -18,20 +17,20 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_without_input_dim(self):
with self.assertRaises(NameError):
_ = prototypes.Prototypes1D(nclasses=2)
_ = prototypes.Prototypes1D(num_classes=2)
def test_prototypes1d_init_without_nclasses(self):
def test_prototypes1d_init_without_num_classes(self):
with self.assertRaises(NameError):
_ = prototypes.Prototypes1D(input_dim=1)
def test_prototypes1d_init_with_nclasses_1(self):
def test_prototypes1d_init_with_num_classes_1(self):
with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
_ = prototypes.Prototypes1D(num_classes=1, input_dim=1)
def test_prototypes1d_init_without_pdist(self):
p1 = prototypes.Prototypes1D(
input_dim=6,
nclasses=2,
num_classes=2,
prototypes_per_class=4,
prototype_initializer="ones",
)
@@ -60,7 +59,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D(
input_dim=3,
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=None,
@@ -81,7 +80,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_without_inputdim_with_data(self):
_ = prototypes.Prototypes1D(
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1.0], [0.0]], [1, 0]],
@@ -89,7 +88,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_with_int_data(self):
_ = prototypes.Prototypes1D(
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]],
@@ -98,7 +97,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_one_hot_without_data(self):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=None,
@@ -112,7 +111,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
@@ -126,7 +125,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [0, 1]),
@@ -141,7 +140,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [[0], [1]]),
@@ -151,7 +150,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_with_int_dtype(self):
with self.assertRaises(RuntimeError):
_ = prototypes.Prototypes1D(
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]],
@@ -161,7 +160,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=1,
num_classes=1,
prototypes_per_class=1,
data=[[1.0], [1]])
@@ -169,20 +168,20 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=2,
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1.0], [0.0]], [1, 0]],
)
def test_prototypes1d_nclasses_with_data(self):
"""Test ValueError raise if provided `nclasses` is not the same
def test_prototypes1d_num_classes_with_data(self):
"""Test ValueError raise if provided `num_classes` is not the same
as the one computed from the provided `data`.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=1,
num_classes=1,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1.0], [2.0]], [1, 2]],
@@ -220,7 +219,7 @@ class TestPrototypes(unittest.TestCase):
p1 = prototypes.Prototypes1D(
input_dim=99,
nclasses=2,
num_classes=2,
prototypes_per_class=1,
prototype_initializer=my_initializer,
)