6 Commits

Author SHA1 Message Date
Alexander Engelsberger
09c80e2d54 Merge branch 'master' into kernel_distances 2021-05-11 16:10:56 +02:00
Alexander Engelsberger
65e0637b17 Fix RBF Kernel Dimensions. 2021-04-27 17:58:05 +02:00
Alexander Engelsberger
209f9e641b Fix kernel dimensions. 2021-04-27 16:56:56 +02:00
Alexander Engelsberger
ba537fe1d5 Automatic formatting. 2021-04-27 15:43:10 +02:00
Alexander Engelsberger
b0cd2de18e Batch Kernel. [Ineficient] 2021-04-27 15:38:34 +02:00
Alexander Engelsberger
7d353f5b5a Kernel Distances. 2021-04-27 12:06:15 +02:00
32 changed files with 937 additions and 421 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.5.0 current_version = 0.4.2
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

3
.gitignore vendored
View File

@@ -154,5 +154,4 @@ scratch*
# End of https://www.gitignore.io/api/visualstudiocode # End of https://www.gitignore.io/api/visualstudiocode
.vscode/ .vscode/
reports reports
artifacts

View File

@@ -4,9 +4,7 @@ language: python
python: 3.8 python: 3.8
cache: cache:
directories: directories:
- "$HOME/.cache/pip"
- "./tests/artifacts" - "./tests/artifacts"
- "$HOME/datasets"
install: install:
- pip install .[all] --progress-bar off - pip install .[all] --progress-bar off

View File

@@ -1,10 +1,5 @@
# ProtoTorch Releases # ProtoTorch Releases
## Release 0.5.0
- Breaking: Removed deprecated `prototorch.modules.Prototypes1D`
- Use `prototorch.components.LabeledComponents` instead
## Release 0.2.0 ## Release 0.2.0
### Includes ### Includes

View File

@@ -1,24 +1,13 @@
.. ProtoTorch API Reference .. ProtoFlow API Reference
ProtoTorch API Reference ProtoFlow API Reference
====================================== ======================================
Datasets Datasets
-------------------------------------- --------------------------------------
Common Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: prototorch.datasets .. automodule:: prototorch.datasets
:members: :members:
:undoc-members:
Abstract Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Abstract Datasets are used to build your own datasets.
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
:members:
Functions Functions
-------------------------------------- --------------------------------------

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.5.0" release = "0.4.2"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
@@ -46,7 +46,6 @@ extensions = [
"sphinx.ext.viewcode", "sphinx.ext.viewcode",
"sphinx_rtd_theme", "sphinx_rtd_theme",
"sphinxcontrib.katex", "sphinxcontrib.katex",
'sphinx_autodoc_typehints',
] ]
# katex_prerender = True # katex_prerender = True
@@ -180,9 +179,6 @@ texinfo_documents = [
intersphinx_mapping = { intersphinx_mapping = {
"python": ("https://docs.python.org/", None), "python": ("https://docs.python.org/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None), "numpy": ("https://docs.scipy.org/doc/numpy/", None),
"torch": ('https://pytorch.org/docs/stable/', None),
"pytorch_lightning":
("https://pytorch-lightning.readthedocs.io/en/stable/", None),
} }
# -- Options for Epub output ---------------------------------------------- # -- Options for Epub output ----------------------------------------------

View File

@@ -3,14 +3,15 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from torchinfo import summary from torchinfo import summary
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
# Prepare and preprocess the data # Prepare and preprocess the data
scaler = StandardScaler() scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True) x_train, y_train = load_iris(return_X_y=True)
@@ -24,17 +25,19 @@ class Model(torch.nn.Module):
def __init__(self): def __init__(self):
"""GLVQ model for training on 2D Iris data.""" """GLVQ model for training on 2D Iris data."""
super().__init__() super().__init__()
prototype_initializer = StratifiedMeanInitializer([x_train, y_train]) self.proto_layer = Prototypes1D(
prototype_distribution = {"num_classes": 3, "prototypes_per_class": 3} input_dim=2,
self.proto_layer = LabeledComponents( prototypes_per_class=3,
prototype_distribution, nclasses=3,
prototype_initializer, prototype_initializer="stratified_random",
data=[x_train, y_train],
) )
def forward(self, x): def forward(self, x):
prototypes, prototype_labels = self.proto_layer() protos = self.proto_layer.prototypes
distances = euclidean_distance(x, prototypes) plabels = self.proto_layer.prototype_labels
return distances, prototype_labels dis = euclidean_distance(x, protos)
return dis, plabels
# Build the GLVQ model # Build the GLVQ model
@@ -51,46 +54,43 @@ x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train) y_in = torch.Tensor(y_train)
# Training loop # Training loop
TITLE = "Prototype Visualization" title = "Prototype Visualization"
fig = plt.figure(TITLE) fig = plt.figure(title)
for epoch in range(70): for epoch in range(70):
# Compute loss # Compute loss
distances, prototype_labels = model(x_in) dis, plabels = model(x_in)
loss = criterion([distances, prototype_labels], y_in) loss = criterion([dis, plabels], y_in)
# Compute Accuracy
with torch.no_grad(): with torch.no_grad():
predictions = wtac(distances, prototype_labels) pred = wtac(dis, plabels)
correct = predictions.eq(y_in.view_as(predictions)).sum().item() correct = pred.eq(y_in.view_as(pred)).sum().item()
acc = 100.0 * correct / len(x_train) acc = 100.0 * correct / len(x_train)
print( print(
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%" f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
) )
# Optimizer step # Take a gradient descent step
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Get the prototypes form the model # Get the prototypes form the model
prototypes = model.proto_layer.components.numpy() protos = model.proto_layer.prototypes.data.numpy()
if np.isnan(np.sum(prototypes)): if np.isnan(np.sum(protos)):
print("Stopping training because of `nan` in prototypes.") print("Stopping training because of `nan` in prototypes.")
break break
# Visualize the data and the prototypes # Visualize the data and the prototypes
ax = fig.gca() ax = fig.gca()
ax.cla() ax.cla()
ax.set_title(TITLE) ax.set_title(title)
ax.set_xlabel("Data dimension 1") ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
cmap = "viridis" cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter( ax.scatter(
prototypes[:, 0], protos[:, 0],
prototypes[:, 1], protos[:, 1],
c=prototype_labels, c=plabels,
cmap=cmap, cmap=cmap,
edgecolor="k", edgecolor="k",
marker="D", marker="D",
@@ -98,7 +98,7 @@ for epoch in range(70):
) )
# Paint decision regions # Paint decision regions
x = np.vstack((x_train, prototypes)) x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50), xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
@@ -108,7 +108,7 @@ for epoch in range(70):
torch_input = torch.Tensor(mesh_input) torch_input = torch.Tensor(mesh_input)
d = model(torch_input)[0] d = model(torch_input)[0]
w_indices = torch.argmin(d, dim=1) w_indices = torch.argmin(d, dim=1)
y_pred = torch.index_select(prototype_labels, 0, w_indices) y_pred = torch.index_select(plabels, 0, w_indices)
y_pred = y_pred.reshape(xx.shape) y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions # Plot voronoi regions

View File

@@ -2,12 +2,13 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer from torch.utils.data import DataLoader
from prototorch.datasets.tecator import Tecator from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed from prototorch.functions.distances import sed
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles from prototorch.utils.colors import get_legend_handles
from torch.utils.data import DataLoader
# Prepare the dataset and dataloader # Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True) train_data = Tecator(root="./artifacts", train=True)
@@ -18,22 +19,22 @@ class Model(torch.nn.Module):
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""GMLVQ model as a siamese network.""" """GMLVQ model as a siamese network."""
super().__init__() super().__init__()
prototype_initializer = StratifiedMeanInitializer(train_loader) x, y = train_data.data, train_data.targets
prototype_distribution = {"num_classes": 2, "prototypes_per_class": 2} self.p1 = Prototypes1D(
input_dim=100,
self.proto_layer = LabeledComponents( prototypes_per_class=2,
prototype_distribution, nclasses=2,
prototype_initializer, prototype_initializer="stratified_random",
data=[x, y],
) )
self.omega = torch.nn.Linear(in_features=100, self.omega = torch.nn.Linear(in_features=100,
out_features=100, out_features=100,
bias=False) bias=False)
torch.nn.init.eye_(self.omega.weight) torch.nn.init.eye_(self.omega.weight)
def forward(self, x): def forward(self, x):
protos = self.proto_layer.components protos = self.p1.prototypes
plabels = self.proto_layer.component_labels plabels = self.p1.prototype_labels
# Process `x` and `protos` through `omega` # Process `x` and `protos` through `omega`
x_map = self.omega(x) x_map = self.omega(x)
@@ -85,8 +86,8 @@ im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show() plt.show()
# Get the prototypes form the model # Get the prototypes form the model
protos = model.proto_layer.components.numpy() protos = model.p1.prototypes.data.numpy()
plabels = model.proto_layer.component_labels.numpy() plabels = model.p1.prototype_labels
# Visualize the prototypes # Visualize the prototypes
title = "Tecator Prototypes" title = "Tecator Prototypes"

View File

@@ -12,19 +12,20 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from torchvision import transforms
from prototorch.functions.helper import calculate_prototype_accuracy from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ from prototorch.modules.models import GTLVQ
from torchvision import transforms
# Parameters and options # Parameters and options
num_epochs = 50 n_epochs = 50
batch_size_train = 64 batch_size_train = 64
batch_size_test = 1000 batch_size_test = 1000
learning_rate = 0.1 learning_rate = 0.1
momentum = 0.5 momentum = 0.5
log_interval = 10 log_interval = 10
cuda = "cuda:0" cuda = "cuda:1"
random_seed = 1 random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else "cpu") device = torch.device(cuda if torch.cuda.is_available() else "cpu")
@@ -140,14 +141,14 @@ optimizer = torch.optim.Adam(
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10) criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
# Training loop # Training loop
for epoch in range(num_epochs): for epoch in range(n_epochs):
for batch_idx, (x_train, y_train) in enumerate(train_loader): for batch_idx, (x_train, y_train) in enumerate(train_loader):
model.train() model.train()
x_train, y_train = x_train.to(device), y_train.to(device) x_train, y_train = x_train.to(device), y_train.to(device)
optimizer.zero_grad() optimizer.zero_grad()
distances = model(x_train) distances = model(x_train)
plabels = model.gtlvq.cls.component_labels.to(device) plabels = model.gtlvq.cls.prototype_labels.to(device)
# Compute loss. # Compute loss.
loss = criterion([distances, plabels], y_train) loss = criterion([distances, plabels], y_train)
@@ -160,7 +161,7 @@ for epoch in range(num_epochs):
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
acc = calculate_prototype_accuracy(distances, y_train, plabels) acc = calculate_prototype_accuracy(distances, y_train, plabels)
print( print(
f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \ f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
Train Acc: {acc.item():02.02f}") Train Acc: {acc.item():02.02f}")
# Test # Test

View File

@@ -3,13 +3,15 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.modules.losses import GLVQLoss
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.functions.init import eye_
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D
# Prepare training data # Prepare training data
x_train, y_train = load_iris(True) x_train, y_train = load_iris(True)
x_train = x_train[:, [0, 2]] x_train = x_train[:, [0, 2]]
@@ -20,19 +22,19 @@ class Model(torch.nn.Module):
def __init__(self): def __init__(self):
"""Local-GMLVQ model.""" """Local-GMLVQ model."""
super().__init__() super().__init__()
self.p1 = Prototypes1D(
prototype_initializer = StratifiedMeanInitializer([x_train, y_train]) input_dim=2,
prototype_distribution = [1, 2, 2] prototype_distribution=[1, 2, 2],
self.proto_layer = LabeledComponents( prototype_initializer="stratified_random",
prototype_distribution, data=[x_train, y_train],
prototype_initializer,
) )
omegas = torch.zeros(5, 2, 2)
omegas = torch.eye(2, 2).repeat(5, 1, 1)
self.omegas = torch.nn.Parameter(omegas) self.omegas = torch.nn.Parameter(omegas)
eye_(self.omegas)
def forward(self, x): def forward(self, x):
protos, plabels = self.proto_layer() protos = self.p1.prototypes
plabels = self.p1.prototype_labels
omegas = self.omegas omegas = self.omegas
dis = lomega_distance(x, protos, omegas) dis = lomega_distance(x, protos, omegas)
return dis, plabels return dis, plabels
@@ -67,7 +69,7 @@ for epoch in range(100):
optimizer.step() optimizer.step()
# Get the prototypes form the model # Get the prototypes form the model
protos = model.proto_layer.components.numpy() protos = model.p1.prototypes.data.numpy()
# Visualize the data and the prototypes # Visualize the data and the prototypes
ax = fig.gca() ax = fig.gca()

View File

@@ -1,24 +1,21 @@
"""ProtoTorch package.""" """ProtoTorch package."""
import pkgutil
import pkg_resources
from . import components, datasets, functions, modules, utils
from .datasets import *
# Core Setup # Core Setup
__version__ = "0.5.0" __version__ = "0.4.2"
__all_core__ = [ __all_core__ = [
"datasets", "datasets",
"functions", "functions",
"modules", "modules",
"components",
"utils",
] ]
from .datasets import *
# Plugin Loader # Plugin Loader
import pkgutil
import pkg_resources
__path__ = pkgutil.extend_path(__path__, __name__) __path__ = pkgutil.extend_path(__path__, __name__)

View File

@@ -1,37 +1,36 @@
"""ProtoTorch components modules.""" """ProtoTorch components modules."""
import warnings import warnings
from typing import Tuple
import torch import torch
from prototorch.components.initializers import (ClassAwareInitializer, from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer, ComponentsInitializer,
CustomLabelsInitializer,
EqualLabelsInitializer, EqualLabelsInitializer,
UnequalLabelsInitializer, UnequalLabelsInitializer,
ZeroReasoningsInitializer) ZeroReasoningsInitializer)
from prototorch.functions.initializers import get_initializer
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
class Components(torch.nn.Module): class Components(torch.nn.Module):
"""Components is a set of learnable Tensors.""" """Components is a set of learnable Tensors."""
def __init__(self, def __init__(self,
num_components=None, number_of_components=None,
initializer=None, initializer=None,
*, *,
initialized_components=None): initialized_components=None,
dtype=torch.float32):
super().__init__() super().__init__()
self.num_components = num_components
# Ignore all initialization settings if initialized_components is given. # Ignore all initialization settings if initialized_components is given.
if initialized_components is not None: if initialized_components is not None:
self.register_parameter("_components", self._components = Parameter(initialized_components)
Parameter(initialized_components)) if number_of_components is not None or initializer is not None:
if num_components is not None or initializer is not None:
wmsg = "Arguments ignored while initializing Components" wmsg = "Arguments ignored while initializing Components"
warnings.warn(wmsg) warnings.warn(wmsg)
else: else:
self._initialize_components(initializer) self._initialize_components(number_of_components, initializer)
def _precheck_initializer(self, initializer): def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer): if not isinstance(initializer, ComponentsInitializer):
@@ -40,15 +39,15 @@ class Components(torch.nn.Module):
f"You have provided: {initializer=} instead." f"You have provided: {initializer=} instead."
raise TypeError(emsg) raise TypeError(emsg)
def _initialize_components(self, initializer): def _initialize_components(self, number_of_components, initializer):
self._precheck_initializer(initializer) self._precheck_initializer(initializer)
_components = initializer.generate(self.num_components) self._components = Parameter(
self.register_parameter("_components", Parameter(_components)) initializer.generate(number_of_components))
@property @property
def components(self): def components(self):
"""Tensor containing the component tensors.""" """Tensor containing the component tensors."""
return self._components.detach() return self._components.detach().cpu()
def forward(self): def forward(self):
return self._components return self._components
@@ -68,44 +67,36 @@ class LabeledComponents(Components):
*, *,
initialized_components=None): initialized_components=None):
if initialized_components is not None: if initialized_components is not None:
components, component_labels = initialized_components super().__init__(initialized_components=initialized_components[0])
super().__init__(initialized_components=components) self._labels = initialized_components[1]
self._labels = component_labels
else: else:
_labels = self._initialize_labels(distribution) self._initialize_labels(distribution)
super().__init__(len(_labels), initializer=initializer) super().__init__(number_of_components=len(self._labels),
self.register_buffer("_labels", _labels) initializer=initializer)
def _initialize_components(self, initializer): def _initialize_components(self, number_of_components, initializer):
if isinstance(initializer, ClassAwareInitializer): if isinstance(initializer, ClassAwareInitializer):
self._precheck_initializer(initializer) self._precheck_initializer(initializer)
_components = initializer.generate(self.num_components, self._components = Parameter(
self.distribution) initializer.generate(number_of_components, self.distribution))
self.register_parameter("_components", Parameter(_components))
else: else:
super()._initialize_components(initializer) super()._initialize_components(self, number_of_components,
initializer)
def _initialize_labels(self, distribution): def _initialize_labels(self, distribution):
if type(distribution) == dict: if type(distribution) == tuple:
if "num_classes" in distribution.keys():
labels = EqualLabelsInitializer(
distribution["num_classes"],
distribution["prototypes_per_class"])
else:
labels = CustomLabelsInitializer(distribution)
elif type(distribution) == tuple:
num_classes, prototypes_per_class = distribution num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class) labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif type(distribution) == list: elif type(distribution) == list:
labels = UnequalLabelsInitializer(distribution) labels = UnequalLabelsInitializer(distribution)
self.distribution = labels.distribution self.distribution = labels.distribution
return labels.generate() self._labels = labels.generate()
@property @property
def component_labels(self): def component_labels(self):
"""Tensor containing the component tensors.""" """Tensor containing the component tensors."""
return self._labels.detach() return self._labels.detach().cpu()
def forward(self): def forward(self):
return super().forward(), self._labels return super().forward(), self._labels
@@ -132,21 +123,20 @@ class ReasoningComponents(Components):
*, *,
initialized_components=None): initialized_components=None):
if initialized_components is not None: if initialized_components is not None:
components, reasonings = initialized_components super().__init__(initialized_components=initialized_components[0])
self._reasonings = initialized_components[1]
super().__init__(initialized_components=components)
self.register_parameter("_reasonings", reasonings)
else: else:
self._initialize_reasonings(reasonings) self._initialize_reasonings(reasonings)
super().__init__(len(self._reasonings), initializer=initializer) super().__init__(number_of_components=len(self._reasonings),
initializer=initializer)
def _initialize_reasonings(self, reasonings): def _initialize_reasonings(self, reasonings):
if type(reasonings) == tuple: if type(reasonings) == tuple:
num_classes, num_components = reasonings num_classes, number_of_components = reasonings
reasonings = ZeroReasoningsInitializer(num_classes, num_components) reasonings = ZeroReasoningsInitializer(num_classes,
number_of_components)
_reasonings = reasonings.generate() self._reasonings = reasonings.generate()
self.register_parameter("_reasonings", _reasonings)
@property @property
def reasonings(self): def reasonings(self):
@@ -155,7 +145,7 @@ class ReasoningComponents(Components):
Dimension NxCx2 Dimension NxCx2
""" """
return self._reasonings.detach() return self._reasonings.detach().cpu()
def forward(self): def forward(self):
return super().forward(), self._reasonings return super().forward(), self._reasonings

View File

@@ -7,18 +7,12 @@ import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
def parse_data_arg(data_arg): def parse_init_arg(arg):
if isinstance(data_arg, Dataset): if isinstance(arg, Dataset):
data_arg = DataLoader(data_arg, batch_size=len(data_arg)) data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
# data = data.view(len(arg), -1) # flatten
if isinstance(data_arg, DataLoader):
data = torch.tensor([])
labels = torch.tensor([])
for x, y in data_arg:
data = torch.cat([data, x])
labels = torch.cat([labels, y])
else: else:
data, labels = data_arg data, labels = arg
if not isinstance(data, torch.Tensor): if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}." wmsg = f"Converting data to {torch.Tensor}."
warnings.warn(wmsg) warnings.warn(wmsg)
@@ -69,19 +63,19 @@ class UniformInitializer(DimensionAwareInitializer):
return torch.ones(gen_dims).uniform_(self.min, self.max) return torch.ones(gen_dims).uniform_(self.min, self.max)
class DataAwareInitializer(ComponentsInitializer): class PositionAwareInitializer(ComponentsInitializer):
def __init__(self, data): def __init__(self, positions):
super().__init__() super().__init__()
self.data = data self.data = positions
class SelectionInitializer(DataAwareInitializer): class SelectionInitializer(PositionAwareInitializer):
def generate(self, length): def generate(self, length):
indices = torch.LongTensor(length).random_(0, len(self.data)) indices = torch.LongTensor(length).random_(0, len(self.data))
return self.data[indices] return self.data[indices]
class MeanInitializer(DataAwareInitializer): class MeanInitializer(PositionAwareInitializer):
def generate(self, length): def generate(self, length):
mean = torch.mean(self.data, dim=0) mean = torch.mean(self.data, dim=0)
repeat_dim = [length] + [1] * len(mean.shape) repeat_dim = [length] + [1] * len(mean.shape)
@@ -89,14 +83,12 @@ class MeanInitializer(DataAwareInitializer):
class ClassAwareInitializer(ComponentsInitializer): class ClassAwareInitializer(ComponentsInitializer):
def __init__(self, data, transform=torch.nn.Identity()): def __init__(self, arg):
super().__init__() super().__init__()
data, labels = parse_data_arg(data) data, labels = parse_init_arg(arg)
self.data = data self.data = data
self.labels = labels self.labels = labels
self.transform = transform
self.clabels = torch.unique(self.labels) self.clabels = torch.unique(self.labels)
self.num_classes = len(self.clabels) self.num_classes = len(self.clabels)
@@ -104,24 +96,15 @@ class ClassAwareInitializer(ComponentsInitializer):
if not dist: if not dist:
per_class = length // self.num_classes per_class = length // self.num_classes
dist = self.num_classes * [per_class] dist = self.num_classes * [per_class]
if type(dist) == dict:
dist = dist.values()
samples_list = [ samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist) init.generate(n) for init, n in zip(self.initializers, dist)
] ]
out = torch.vstack(samples_list) return torch.vstack(samples_list)
with torch.no_grad():
out = self.transform(out)
return out
def __del__(self):
del self.data
del self.labels
class StratifiedMeanInitializer(ClassAwareInitializer): class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, data, **kwargs): def __init__(self, arg):
super().__init__(data, **kwargs) super().__init__(arg)
self.initializers = [] self.initializers = []
for clabel in self.clabels: for clabel in self.clabels:
@@ -135,8 +118,8 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
class StratifiedSelectionInitializer(ClassAwareInitializer): class StratifiedSelectionInitializer(ClassAwareInitializer):
def __init__(self, data, noise=None, **kwargs): def __init__(self, arg, *, noise=None):
super().__init__(data, **kwargs) super().__init__(arg)
self.noise = noise self.noise = noise
self.initializers = [] self.initializers = []
@@ -145,10 +128,7 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
class_initializer = SelectionInitializer(class_data) class_initializer = SelectionInitializer(class_data)
self.initializers.append(class_initializer) self.initializers.append(class_initializer)
def add_noise_v1(self, x): def add_noise(self, x):
return x + self.noise
def add_noise_v2(self, x):
"""Shifts some dimensions of the data randomly.""" """Shifts some dimensions of the data randomly."""
n1 = torch.rand_like(x) n1 = torch.rand_like(x)
n2 = torch.rand_like(x) n2 = torch.rand_like(x)
@@ -158,7 +138,8 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
def generate(self, length, dist=[]): def generate(self, length, dist=[]):
samples = self._get_samples_from_initializer(length, dist) samples = self._get_samples_from_initializer(length, dist)
if self.noise is not None: if self.noise is not None:
samples = self.add_noise_v1(samples) # samples = self.add_noise(samples)
samples = samples + self.noise
return samples return samples
@@ -176,13 +157,10 @@ class UnequalLabelsInitializer(LabelsInitializer):
def distribution(self): def distribution(self):
return self.dist return self.dist
def generate(self, clabels=None, dist=None): def generate(self):
if not clabels: clabels = range(len(self.dist))
clabels = range(len(self.dist)) labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
if not dist: return torch.tensor(labels)
dist = self.dist
labels = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
return torch.LongTensor(labels)
class EqualLabelsInitializer(LabelsInitializer): class EqualLabelsInitializer(LabelsInitializer):
@@ -198,13 +176,6 @@ class EqualLabelsInitializer(LabelsInitializer):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten() return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
class CustomLabelsInitializer(UnequalLabelsInitializer):
def generate(self):
clabels = list(self.dist.keys())
dist = list(self.dist.values())
return super().generate(clabels, dist)
# Reasonings # Reasonings
class ReasoningsInitializer: class ReasoningsInitializer:
def generate(self, length): def generate(self, length):
@@ -224,5 +195,3 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
SMI = StratifiedMeanInitializer SMI = StratifiedMeanInitializer
Random = RandomInitializer = UniformInitializer Random = RandomInitializer = UniformInitializer
Zeros = ZerosInitializer
Ones = OnesInitializer

View File

@@ -1,8 +1,11 @@
"""ProtoTorch datasets.""" """ProtoTorch datasets."""
from .abstract import NumpyDataset from .abstract import NumpyDataset
from .iris import Iris
from .spiral import Spiral from .spiral import Spiral
from .tecator import Tecator from .tecator import Tecator
__all__ = ['Iris', 'Spiral', 'Tecator'] __all__ = [
"NumpyDataset",
"Spiral",
"Tecator",
]

View File

@@ -14,10 +14,8 @@ import torch
class NumpyDataset(torch.utils.data.TensorDataset): class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays.""" """Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, data, targets): def __init__(self, *arrays):
self.data = data tensors = [torch.Tensor(arr) for arr in arrays]
self.targets = targets
tensors = [torch.Tensor(data), torch.Tensor(targets)]
super().__init__(*tensors) super().__init__(*tensors)

View File

@@ -1,40 +0,0 @@
"""Thin wrapper for the Iris classification dataset from sklearn.
URL:
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html
"""
from typing import Sequence
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris
class Iris(NumpyDataset):
"""
Iris Dataset by Ronald Fisher introduced in 1936.
The dataset contains four measurements from flowers of three species of iris.
.. list-table:: Iris
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 4
- 3
- 150
- 0
- 0
:param dims: select a subset of dimensions
"""
def __init__(self, dims: Sequence[int] = None):
x, y = load_iris(return_X_y=True)
if dims:
x = x[:, dims]
super().__init__(x, y)

View File

@@ -4,22 +4,18 @@ import numpy as np
import torch import torch
def make_spiral(num_samples=500, noise=0.3): def make_spiral(n_samples=500, noise=0.3):
"""Generates the Spiral Dataset.
For use in Prototorch use `prototorch.datasets.Spiral` instead.
"""
def get_samples(n, delta_t): def get_samples(n, delta_t):
points = [] points = []
for i in range(n): for i in range(n):
r = i / num_samples * 5 r = i / n_samples * 5
t = 1.75 * i / n * 2 * np.pi + delta_t t = 1.75 * i / n * 2 * np.pi + delta_t
x = r * np.sin(t) + np.random.rand(1) * noise x = r * np.sin(t) + np.random.rand(1) * noise
y = r * np.cos(t) + np.random.rand(1) * noise y = r * np.cos(t) + np.random.rand(1) * noise
points.append([x, y]) points.append([x, y])
return points return points
n = num_samples // 2 n = n_samples // 2
positive = get_samples(n=n, delta_t=0) positive = get_samples(n=n, delta_t=0)
negative = get_samples(n=n, delta_t=np.pi) negative = get_samples(n=n, delta_t=np.pi)
x = np.concatenate( x = np.concatenate(
@@ -31,27 +27,7 @@ def make_spiral(num_samples=500, noise=0.3):
class Spiral(torch.utils.data.TensorDataset): class Spiral(torch.utils.data.TensorDataset):
"""Spiral dataset for binary classification. """Spiral dataset for binary classification."""
def __init__(self, n_samples=500, noise=0.3):
This datasets consists of two spirals of two different classes. x, y = make_spiral(n_samples, noise)
.. list-table:: Spiral
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 2
- 2
- num_samples
- 0
- 0
:param num_samples: number of random samples
:param noise: noise added to the spirals
"""
def __init__(self, num_samples: int = 500, noise: float = 0.3):
x, y = make_spiral(num_samples, noise)
super().__init__(torch.Tensor(x), torch.LongTensor(y)) super().__init__(torch.Tensor(x), torch.LongTensor(y))

View File

@@ -40,29 +40,15 @@ import os
import numpy as np import numpy as np
import torch import torch
from prototorch.datasets.abstract import ProtoDataset
from torchvision.datasets.utils import download_file_from_google_drive from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset): class Tecator(ProtoDataset):
""" """
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification. `Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
for classification.
The dataset contains wavelength measurements of meat.
.. list-table:: Tecator
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 100
- 2
- 129
- 43
- 43
""" """
_resources = [ _resources = [

View File

@@ -3,14 +3,15 @@
import torch import torch
# @torch.jit.script
def stratified_min(distances, labels): def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0) clabels = torch.unique(labels, dim=0)
num_classes = clabels.size()[0] nclasses = clabels.size()[0]
if distances.size()[1] == num_classes: if distances.size()[1] == nclasses:
# skip if only one prototype per class # skip if only one prototype per class
return distances return distances
batch_size = distances.size()[0] batch_size = distances.size()[0]
winning_distances = torch.zeros(num_classes, batch_size) winning_distances = torch.zeros(nclasses, batch_size)
inf = torch.full_like(distances.T, fill_value=float("inf")) inf = torch.full_like(distances.T, fill_value=float("inf"))
# distances_to_wpluses = torch.where(matcher, distances, inf) # distances_to_wpluses = torch.where(matcher, distances, inf)
for i, cl in enumerate(clabels): for i, cl in enumerate(clabels):
@@ -18,7 +19,7 @@ def stratified_min(distances, labels):
matcher = torch.eq(labels.unsqueeze(dim=1), cl) matcher = torch.eq(labels.unsqueeze(dim=1), cl)
if labels.ndim == 2: if labels.ndim == 2:
# if the labels are one-hot vectors # if the labels are one-hot vectors
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
cdists = torch.where(matcher, distances.T, inf).T cdists = torch.where(matcher, distances.T, inf).T
winning_distances[i] = torch.min(cdists, dim=1, winning_distances[i] = torch.min(cdists, dim=1,
keepdim=True).values.squeeze() keepdim=True).values.squeeze()
@@ -30,15 +31,15 @@ def stratified_min(distances, labels):
return winning_distances.T # return with `batch_size` first return winning_distances.T # return with `batch_size` first
# @torch.jit.script
def wtac(distances, labels): def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze() winning_labels = labels[winning_indices].squeeze()
return winning_labels return winning_labels
def knnc(distances, labels, k=1): # @torch.jit.script
winning_indices = torch.topk(-distances, k=k, dim=1).indices def knnc(distances, labels, k):
# winning_labels = torch.mode(labels[winning_indices].squeeze(), winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
# dim=1).values winning_labels = labels[winning_indices].squeeze()
winning_labels = torch.mode(labels[winning_indices], dim=1).values
return winning_labels return winning_labels

View File

@@ -2,8 +2,12 @@
import numpy as np import numpy as np
import torch import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
equal_int_shape, get_flat) from prototorch.functions.helper import (
_check_shapes,
_int_and_mixed_shape,
equal_int_shape,
)
def squared_euclidean_distance(x, y): def squared_euclidean_distance(x, y):
@@ -11,10 +15,12 @@ def squared_euclidean_distance(x, y):
Compute :math:`{\langle \bm x - \bm y \rangle}_2` Compute :math:`{\langle \bm x - \bm y \rangle}_2`
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
**Alias:** **Alias:**
``prototorch.functions.distances.sed`` ``prototorch.functions.distances.sed``
""" """
x, y = get_flat(x, y)
expanded_x = x.unsqueeze(dim=1) expanded_x = x.unsqueeze(dim=1)
batchwise_difference = y - expanded_x batchwise_difference = y - expanded_x
differences_raised = torch.pow(batchwise_difference, 2) differences_raised = torch.pow(batchwise_difference, 2)
@@ -27,17 +33,18 @@ def euclidean_distance(x, y):
Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}` Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`
:param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
:param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`
:returns: Distance Tensor of shape :math:`X \times Y` :returns: Distance Tensor of shape :math:`X \times Y`
:rtype: `torch.tensor` :rtype: `torch.tensor`
""" """
x, y = get_flat(x, y)
distances_raised = squared_euclidean_distance(x, y) distances_raised = squared_euclidean_distance(x, y)
distances = torch.sqrt(distances_raised) distances = torch.sqrt(distances_raised)
return distances return distances
def euclidean_distance_v2(x, y): def euclidean_distance_v2(x, y):
x, y = get_flat(x, y)
diff = y - x.unsqueeze(1) diff = y - x.unsqueeze(1)
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt() pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
@@ -58,9 +65,10 @@ def lpnorm_distance(x, y, p):
Calls ``torch.cdist`` Calls ``torch.cdist``
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param p: p parameter of the lp norm :param p: p parameter of the lp norm
""" """
x, y = get_flat(x, y)
distances = torch.cdist(x, y, p=p) distances = torch.cdist(x, y, p=p)
return distances return distances
@@ -70,9 +78,10 @@ def omega_distance(x, y, omega):
Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p` Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param `torch.tensor` omega: Two dimensional matrix :param `torch.tensor` omega: Two dimensional matrix
""" """
x, y = get_flat(x, y)
projected_x = x @ omega projected_x = x @ omega
projected_y = y @ omega projected_y = y @ omega
distances = squared_euclidean_distance(projected_x, projected_y) distances = squared_euclidean_distance(projected_x, projected_y)
@@ -84,9 +93,10 @@ def lomega_distance(x, y, omegas):
Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p` Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`
:param `torch.tensor` x: Two dimensional vector
:param `torch.tensor` y: Two dimensional vector
:param `torch.tensor` omegas: Three dimensional matrix :param `torch.tensor` omegas: Three dimensional matrix
""" """
x, y = get_flat(x, y)
projected_x = x @ omegas projected_x = x @ omegas
projected_y = torch.diagonal(y @ omegas).T projected_y = torch.diagonal(y @ omegas).T
expanded_y = torch.unsqueeze(projected_y, dim=1) expanded_y = torch.unsqueeze(projected_y, dim=1)
@@ -254,5 +264,86 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
return diss.permute([1, 0, 2]).squeeze(-1) return diss.permute([1, 0, 2]).squeeze(-1)
class KernelDistance:
r"""Kernel Distance
Distance based on a kernel function.
"""
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
return self._single_call(x_batch, y_batch)
def _single_call(self, x, y):
remove_dims = []
if len(x.shape) == 1:
x = x.unsqueeze(0)
remove_dims.append(0)
if len(y.shape) == 1:
y = y.unsqueeze(0)
remove_dims.append(-1)
output = self.kernel_fn(x, x).diag().unsqueeze(1) - 2 * self.kernel_fn(
x, y) + self.kernel_fn(y, y).diag()
for dim in remove_dims:
output.squeeze_(dim)
return torch.sqrt(output)
class BatchKernelDistance:
r"""Kernel Distance
Distance based on a kernel function.
"""
def __init__(self, kernel_fn):
self.kernel_fn = kernel_fn
def __call__(self, x_batch: torch.Tensor, y_batch: torch.Tensor):
remove_dims = 0
# Extend Single inputs
if len(x_batch.shape) == 1:
x_batch = x_batch.unsqueeze(0)
remove_dims += 1
if len(y_batch.shape) == 1:
y_batch = y_batch.unsqueeze(0)
remove_dims += 1
# Loop over batches
output = torch.FloatTensor(len(x_batch), len(y_batch))
for i, x in enumerate(x_batch):
for j, y in enumerate(y_batch):
output[i][j] = self._single_call(x, y)
for _ in range(remove_dims):
output.squeeze_(0)
return output
def _single_call(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)
squared_distance = kappa_xx - 2 * kappa_xy + kappa_yy
return torch.sqrt(squared_distance)
class SquaredKernelDistance(KernelDistance):
r"""Squared Kernel Distance
Kernel distance without final squareroot.
"""
def single_call(self, x, y):
kappa_xx = self.kernel_fn(x, x)
kappa_xy = self.kernel_fn(x, y)
kappa_yy = self.kernel_fn(y, y)
return kappa_xx - 2 * kappa_xy + kappa_yy
# Aliases # Aliases
sed = squared_euclidean_distance sed = squared_euclidean_distance

View File

@@ -1,11 +1,6 @@
import torch import torch
def get_flat(*args):
rv = [x.view(x.size(0), -1) for x in args]
return rv
def calculate_prototype_accuracy(y_pred, y_true, plabels): def calculate_prototype_accuracy(y_pred, y_true, plabels):
"""Computes the accuracy of a prototype based model. """Computes the accuracy of a prototype based model.
via Winner-Takes-All rule. via Winner-Takes-All rule.

View File

@@ -15,59 +15,59 @@ def register_initializer(function):
def labels_from(distribution, one_hot=True): def labels_from(distribution, one_hot=True):
"""Takes a distribution tensor and returns a labels tensor.""" """Takes a distribution tensor and returns a labels tensor."""
num_classes = distribution.shape[0] nclasses = distribution.shape[0]
llist = [[i] * n for i, n in zip(range(num_classes), distribution)] llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
# labels = [l for cl in llist for l in cl] # flatten the list of lists # labels = [l for cl in llist for l in cl] # flatten the list of lists
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
plabels = torch.tensor(flat_llist, requires_grad=False) plabels = torch.tensor(flat_llist, requires_grad=False)
if one_hot: if one_hot:
return torch.eye(num_classes)[plabels] return torch.eye(nclasses)[plabels]
return plabels return plabels
@register_initializer @register_initializer
def ones(x_train, y_train, prototype_distribution, one_hot=True): def ones(x_train, y_train, prototype_distribution, one_hot=True):
num_protos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.ones(num_protos, *x_train.shape[1:]) protos = torch.ones(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def zeros(x_train, y_train, prototype_distribution, one_hot=True): def zeros(x_train, y_train, prototype_distribution, one_hot=True):
num_protos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.zeros(num_protos, *x_train.shape[1:]) protos = torch.zeros(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def rand(x_train, y_train, prototype_distribution, one_hot=True): def rand(x_train, y_train, prototype_distribution, one_hot=True):
num_protos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.rand(num_protos, *x_train.shape[1:]) protos = torch.rand(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def randn(x_train, y_train, prototype_distribution, one_hot=True): def randn(x_train, y_train, prototype_distribution, one_hot=True):
num_protos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.randn(num_protos, *x_train.shape[1:]) protos = torch.randn(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True): def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
num_protos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1] pdim = x_train.shape[1]
protos = torch.empty(num_protos, pdim) protos = torch.empty(nprotos, pdim)
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels): for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train) matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot: if one_hot:
num_classes = y_train.size()[1] nclasses = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
xl = x_train[matcher] xl = x_train[matcher]
mean_xl = torch.mean(xl, dim=0) mean_xl = torch.mean(xl, dim=0)
protos[i] = mean_xl protos[i] = mean_xl
@@ -81,15 +81,15 @@ def stratified_random(x_train,
prototype_distribution, prototype_distribution,
one_hot=True, one_hot=True,
epsilon=1e-7): epsilon=1e-7):
num_protos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1] pdim = x_train.shape[1]
protos = torch.empty(num_protos, pdim) protos = torch.empty(nprotos, pdim)
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels): for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train) matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot: if one_hot:
num_classes = y_train.size()[1] nclasses = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
xl = x_train[matcher] xl = x_train[matcher]
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1) rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
random_xl = xl[rand_index] random_xl = xl[rand_index]

View File

@@ -0,0 +1,28 @@
"""
Experimental Kernels
"""
import torch
class ExplicitKernel:
def __init__(self, projection=torch.nn.Identity()):
self.projection = projection
def __call__(self, x, y):
return self.projection(x) @ self.projection(y).T
class RadialBasisFunctionKernel:
def __init__(self, sigma) -> None:
self.s2 = sigma * sigma
def __call__(self, x, y):
remove_dim = False
if len(x.shape) > 1:
x = x.unsqueeze(1)
remove_dim = True
output = torch.exp(-torch.sum((x - y)**2, dim=-1) / (2 * self.s2))
if remove_dim:
output = output.squeeze(1)
return output

View File

@@ -8,12 +8,12 @@ def _get_matcher(targets, labels):
matcher = torch.eq(targets.unsqueeze(dim=1), labels) matcher = torch.eq(targets.unsqueeze(dim=1), labels)
if labels.ndim == 2: if labels.ndim == 2:
# if the labels are one-hot vectors # if the labels are one-hot vectors
num_classes = targets.size()[1] nclasses = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
return matcher return matcher
def _get_dp_dm(distances, targets, plabels, with_indices=False): def _get_dp_dm(distances, targets, plabels):
"""Returns the d+ and d- values for a batch of distances.""" """Returns the d+ and d- values for a batch of distances."""
matcher = _get_matcher(targets, plabels) matcher = _get_matcher(targets, plabels)
not_matcher = torch.bitwise_not(matcher) not_matcher = torch.bitwise_not(matcher)
@@ -21,11 +21,9 @@ def _get_dp_dm(distances, targets, plabels, with_indices=False):
inf = torch.full_like(distances, fill_value=float("inf")) inf = torch.full_like(distances, fill_value=float("inf"))
d_matching = torch.where(matcher, distances, inf) d_matching = torch.where(matcher, distances, inf)
d_unmatching = torch.where(not_matcher, distances, inf) d_unmatching = torch.where(not_matcher, distances, inf)
dp = torch.min(d_matching, dim=-1, keepdim=True) dp = torch.min(d_matching, dim=1, keepdim=True).values
dm = torch.min(d_unmatching, dim=-1, keepdim=True) dm = torch.min(d_unmatching, dim=1, keepdim=True).values
if with_indices: return dp, dm
return dp, dm
return dp.values, dm.values
def glvq_loss(distances, target_labels, prototype_labels): def glvq_loss(distances, target_labels, prototype_labels):
@@ -49,11 +47,10 @@ def lvq1_loss(distances, target_labels, prototype_labels):
def lvq21_loss(distances, target_labels, prototype_labels): def lvq21_loss(distances, target_labels, prototype_labels):
"""LVQ2.1 loss function with support for one-hot labels. """LVQ2.1 loss function with support for one-hot labels.
See Section 4 [Sado&Yamada] See Section 4 [Sado&Yamada]
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
""" """
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp - dm mu = dp - dm
return mu
return mu

View File

@@ -1 +1,7 @@
"""ProtoTorch modules.""" """ProtoTorch modules."""
from .prototypes import Prototypes1D
__all__ = [
"Prototypes1D",
]

View File

@@ -1,9 +1,11 @@
import torch import torch
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from torch import nn from torch import nn
from prototorch.functions.distances import euclidean_distance_matrix, tangent_distance
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
class GTLVQ(nn.Module): class GTLVQ(nn.Module):
r""" Generalized Tangent Learning Vector Quantization r""" Generalized Tangent Learning Vector Quantization
@@ -77,35 +79,45 @@ class GTLVQ(nn.Module):
super(GTLVQ, self).__init__() super(GTLVQ, self).__init__()
self.num_protos = num_classes * prototypes_per_class self.num_protos = num_classes * prototypes_per_class
self.num_protos_class = prototypes_per_class
self.subspace_size = feature_dim if subspace_size is None else subspace_size self.subspace_size = feature_dim if subspace_size is None else subspace_size
self.feature_dim = feature_dim self.feature_dim = feature_dim
self.num_classes = num_classes
cls_initializer = StratifiedMeanInitializer(prototype_data)
cls_distribution = {
"num_classes": num_classes,
"prototypes_per_class": prototypes_per_class,
}
self.cls = LabeledComponents(cls_distribution, cls_initializer)
if subspace_data is None: if subspace_data is None:
raise ValueError("Init Data must be specified!") raise ValueError("Init Data must be specified!")
self.tpt = tangent_projection_type self.tpt = tangent_projection_type
with torch.no_grad(): with torch.no_grad():
if self.tpt == "local": if self.tpt == "local" or self.tpt == "local_proj":
self.init_local_subspace(subspace_data, subspace_size, self.init_local_subspace(subspace_data)
self.num_protos)
elif self.tpt == "global": elif self.tpt == "global":
self.init_gobal_subspace(subspace_data, subspace_size) self.init_gobal_subspace(subspace_data, subspace_size)
else: else:
self.subspaces = None self.subspaces = None
# Hypothesis-Margin-Classifier
self.cls = Prototypes1D(
input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
prototype_initializer="stratified_mean",
data=prototype_data,
)
def forward(self, x): def forward(self, x):
if self.tpt == "local": # Tangent Projection
dis = self.local_tangent_distances(x) if self.tpt == "local_proj":
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis, proj_x = self.local_tangent_projection(x_conform)
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
self.feature_dim)
return proj_x, dis
elif self.tpt == "local":
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2))
dis = tangent_distance(x_conform, self.cls.prototypes,
self.subspaces)
elif self.tpt == "gloabl": elif self.tpt == "gloabl":
dis = self.global_tangent_distances(x) dis = self.global_tangent_distances(x)
else: else:
@@ -118,14 +130,16 @@ class GTLVQ(nn.Module):
_, _, v = torch.svd(data) _, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = subspace[:, :num_subspaces] subspaces = subspace[:, :num_subspaces]
self.subspaces = nn.Parameter(subspaces, requires_grad=True) self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True))
def init_local_subspace(self, data, num_subspaces, num_protos): def init_local_subspace(self, data):
data = data - torch.mean(data, dim=0) _, _, v = torch.svd(data)
_, _, v = torch.svd(data, some=False) inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
v = v[:, :num_subspaces] subspaces = inital_projector.unsqueeze(0).repeat_interleave(
subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0) self.num_protos, 0)
self.subspaces = nn.Parameter(subspaces, requires_grad=True) self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True))
def global_tangent_distances(self, x): def global_tangent_distances(self, x):
# Tangent Projection # Tangent Projection
@@ -136,26 +150,37 @@ class GTLVQ(nn.Module):
# Euclidean Distance # Euclidean Distance
return euclidean_distance_matrix(x, projected_prototypes) return euclidean_distance_matrix(x, projected_prototypes)
def local_tangent_distances(self, x): def local_tangent_projection(self, signals):
# Note: subspaces is always assumed as transposed and must be orthogonal!
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
# shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
# subspace should be orthogonalized
# Origin Source Code
# Origin Author:
protos = self.cls.prototypes
subspaces = self.subspaces
signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
_, proto_int_shape = _int_and_mixed_shape(protos)
# Tangent Distance # check if the shapes are correct
x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components, _check_shapes(signal_int_shape, proto_int_shape)
x.size(-1))
protos = self.cls()[0].unsqueeze(0).expand(x.size(0), # Tangent Data Projections
self.cls.num_components, projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
x.size(-1)) data = signals.squeeze(2).permute([1, 0, 2])
projectors = torch.eye( projected_data = torch.bmm(data, subspaces)
self.subspaces.shape[-2], device=x.device) - torch.bmm( projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
self.subspaces, self.subspaces.permute([0, 2, 1])) diff = projected_data - projected_protos
diff = (x - protos) projected_diff = torch.reshape(
diff = diff.permute([1, 0, 2]) diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
diff = torch.bmm(diff, projectors) signal_shape[3:])
diff = torch.norm(diff, 2, dim=-1).T diss = torch.norm(projected_diff, 2, dim=-1)
return diff return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
def get_parameters(self): def get_parameters(self):
return { return {
"params": self.cls.components, "params": self.cls.prototypes,
}, { }, {
"params": self.subspaces "params": self.subspaces
} }

View File

@@ -0,0 +1,137 @@
"""ProtoTorch prototype modules."""
import warnings
import torch
from prototorch.functions.initializers import get_initializer
class _Prototypes(torch.nn.Module):
"""Abstract prototypes class."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _validate_prototype_distribution(self):
if 0 in self.prototype_distribution:
warnings.warn("Are you sure about the `0` in "
"`prototype_distribution`?")
def extra_repr(self):
return f"prototypes.shape: {tuple(self.prototypes.shape)}"
def forward(self):
return self.prototypes, self.prototype_labels
class Prototypes1D(_Prototypes):
"""Create a learnable set of one-dimensional prototypes.
TODO Complete this doc-string.
"""
def __init__(
self,
prototypes_per_class=1,
prototype_initializer="ones",
prototype_distribution=None,
data=None,
dtype=torch.float32,
one_hot_labels=False,
**kwargs,
):
warnings.warn(
PendingDeprecationWarning(
"Prototypes1D will be replaced in future versions."))
# Convert tensors to python lists before processing
if prototype_distribution is not None:
if not isinstance(prototype_distribution, list):
prototype_distribution = prototype_distribution.tolist()
if data is None:
if "input_dim" not in kwargs:
raise NameError("`input_dim` required if "
"no `data` is provided.")
if prototype_distribution:
kwargs_nclasses = sum(prototype_distribution)
else:
if "nclasses" not in kwargs:
raise NameError("`prototype_distribution` required if "
"both `data` and `nclasses` are not "
"provided.")
kwargs_nclasses = kwargs.pop("nclasses")
input_dim = kwargs.pop("input_dim")
if prototype_initializer in [
"stratified_mean", "stratified_random"
]:
warnings.warn(
f"`prototype_initializer`: `{prototype_initializer}` "
"requires `data`, but `data` is not provided. "
"Using randomly generated data instead.")
x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(kwargs_nclasses)
if one_hot_labels:
y_train = torch.eye(kwargs_nclasses)[y_train]
data = [x_train, y_train]
x_train, y_train = data
x_train = torch.as_tensor(x_train).type(dtype)
y_train = torch.as_tensor(y_train).type(torch.int)
nclasses = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1:
warnings.warn("Are you sure about having one class only?")
if x_train.ndim != 2:
raise ValueError("`data[0].ndim != 2`.")
if y_train.ndim == 2:
if y_train.shape[1] == 1 and one_hot_labels:
raise ValueError("`one_hot_labels` is set to `True` "
"but target labels are not one-hot-encoded.")
if y_train.shape[1] != 1 and not one_hot_labels:
raise ValueError("`one_hot_labels` is set to `False` "
"but target labels in `data` "
"are one-hot-encoded.")
if y_train.ndim == 1 and one_hot_labels:
raise ValueError("`one_hot_labels` is set to `True` "
"but target labels are not one-hot-encoded.")
# Verify input dimension if `input_dim` is provided
if "input_dim" in kwargs:
input_dim = kwargs.pop("input_dim")
if input_dim != x_train.shape[1]:
raise ValueError(f"Provided `input_dim`={input_dim} does "
"not match data dimension "
f"`data[0].shape[1]`={x_train.shape[1]}")
# Verify the number of classes if `nclasses` is provided
if "nclasses" in kwargs:
kwargs_nclasses = kwargs.pop("nclasses")
if kwargs_nclasses != nclasses:
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
"not match data labels "
"`torch.unique(data[1]).shape[0]`"
f"={nclasses}")
super().__init__(**kwargs)
if not prototype_distribution:
prototype_distribution = [prototypes_per_class] * nclasses
with torch.no_grad():
self.prototype_distribution = torch.tensor(prototype_distribution)
self._validate_prototype_distribution()
self.prototype_initializer = get_initializer(prototype_initializer)
prototypes, prototype_labels = self.prototype_initializer(
x_train,
y_train,
prototype_distribution=self.prototype_distribution,
one_hot=one_hot_labels,
)
# Register module parameters
self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = torch.nn.Parameter(
prototype_labels.type(dtype)).requires_grad_(False)

View File

@@ -20,7 +20,6 @@ INSTALL_REQUIRES = [
"torch>=1.3.1", "torch>=1.3.1",
"torchvision>=0.5.0", "torchvision>=0.5.0",
"numpy>=1.9.1", "numpy>=1.9.1",
"sklearn",
] ]
DATASETS = [ DATASETS = [
"requests", "requests",
@@ -32,9 +31,9 @@ DOCS = [
"sphinx", "sphinx",
"sphinx_rtd_theme", "sphinx_rtd_theme",
"sphinxcontrib-katex", "sphinxcontrib-katex",
"sphinx-autodoc-typehints",
] ]
EXAMPLES = [ EXAMPLES = [
"sklearn",
"matplotlib", "matplotlib",
"torchinfo", "torchinfo",
] ]
@@ -43,7 +42,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
name="prototorch", name="prototorch",
version="0.5.0", version="0.4.2",
description="Highly extensible, GPU-supported " description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox " "Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.", "built using PyTorch and its nn API.",

View File

@@ -1,25 +0,0 @@
"""ProtoTorch components test suite."""
import prototorch as pt
import torch
def test_labcomps_zeros_init():
protos = torch.zeros(3, 2)
c = pt.components.LabeledComponents(
distribution=[1, 1, 1],
initializer=pt.components.Zeros(2),
)
assert (c.components == protos).any() == True
def test_labcomps_warmstart():
protos = torch.randn(3, 2)
plabels = torch.tensor([1, 2, 3])
c = pt.components.LabeledComponents(
distribution=[1, 1, 1],
initializer=None,
initialized_components=[protos, plabels],
)
assert (c.components == protos).any() == True
assert (c.component_labels == plabels).any() == True

View File

@@ -4,8 +4,14 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from prototorch.functions import (activations, competitions, distances,
initializers, losses) from prototorch.functions import (
activations,
competitions,
distances,
initializers,
losses,
)
class TestActivations(unittest.TestCase): class TestActivations(unittest.TestCase):
@@ -138,7 +144,7 @@ class TestCompetitions(unittest.TestCase):
def test_knnc_k1(self): def test_knnc_k1(self):
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]]) d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
labels = torch.tensor([0, 1, 2, 3]) labels = torch.tensor([0, 1, 2, 3])
actual = competitions.knnc(d, labels, k=1) actual = competitions.knnc(d, labels, k=torch.tensor([1]))
desired = torch.tensor([2, 0]) desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,

98
tests/test_kernels.py Normal file
View File

@@ -0,0 +1,98 @@
"""ProtoTorch kernels test suite."""
import unittest
import numpy as np
import torch
from prototorch.functions.distances import KernelDistance
from prototorch.functions.kernels import ExplicitKernel, RadialBasisFunctionKernel
class TestExplicitKernel(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
def test_single_values(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
kernel = ExplicitKernel()
self.assertEqual(
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
class TestRadialBasisFunctionKernel(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
def test_single_values(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
kernel = RadialBasisFunctionKernel(1)
self.assertEqual(
kernel(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))
class TestKernelDistance(unittest.TestCase):
def setUp(self):
self.single_x = torch.randn(1024)
self.single_y = torch.randn(1024)
self.batch_x = torch.randn(32, 1024)
self.batch_y = torch.randn(32, 1024)
self.kernel = ExplicitKernel()
def test_single_values(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.single_x, self.single_y).shape, torch.Size([]))
def test_single_batch(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.single_x, self.batch_y).shape, torch.Size([32]))
def test_batch_single(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.batch_x, self.single_y).shape, torch.Size([32]))
def test_batch_values(self):
distance = KernelDistance(self.kernel)
self.assertEqual(
distance(self.batch_x, self.batch_y).shape, torch.Size([32, 32]))

298
tests/test_modules.py Normal file
View File

@@ -0,0 +1,298 @@
"""ProtoTorch modules test suite."""
import unittest
import numpy as np
import torch
from prototorch.modules import losses, prototypes
class TestPrototypes(unittest.TestCase):
def setUp(self):
self.x = torch.tensor(
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
dtype=torch.float32)
self.y = torch.tensor([0, 0, 1, 1])
self.gen = torch.manual_seed(42)
def test_prototypes1d_init_without_input_dim(self):
with self.assertRaises(NameError):
_ = prototypes.Prototypes1D(nclasses=2)
def test_prototypes1d_init_without_nclasses(self):
with self.assertRaises(NameError):
_ = prototypes.Prototypes1D(input_dim=1)
def test_prototypes1d_init_with_nclasses_1(self):
with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
def test_prototypes1d_init_without_pdist(self):
p1 = prototypes.Prototypes1D(
input_dim=6,
nclasses=2,
prototypes_per_class=4,
prototype_initializer="ones",
)
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.ones(8, 6)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_init_without_data(self):
pdist = [2, 2]
p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist,
prototype_initializer="zeros")
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_proto_init_without_data(self):
with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D(
input_dim=3,
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=None,
)
def test_prototypes1d_init_torch_pdist(self):
pdist = torch.tensor([2, 2])
p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist,
prototype_initializer="zeros")
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_init_without_inputdim_with_data(self):
_ = prototypes.Prototypes1D(
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1.0], [0.0]], [1, 0]],
)
def test_prototypes1d_init_with_int_data(self):
_ = prototypes.Prototypes1D(
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]],
)
def test_prototypes1d_init_one_hot_without_data(self):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=None,
one_hot_labels=True,
)
def test_prototypes1d_init_one_hot_labels_false(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
but the provided `data` has one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
one_hot_labels=False,
)
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
but the provided `data` does not contain one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [0, 1]),
one_hot_labels=True,
)
def test_prototypes1d_init_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
but the provided `data` contains 2D targets but
does not contain one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [[0], [1]]),
one_hot_labels=True,
)
def test_prototypes1d_init_with_int_dtype(self):
with self.assertRaises(RuntimeError):
_ = prototypes.Prototypes1D(
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]],
dtype=torch.int32,
)
def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=1,
prototypes_per_class=1,
data=[[1.0], [1]])
def test_prototypes1d_inputdim_with_data(self):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=2,
nclasses=2,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1.0], [0.0]], [1, 0]],
)
def test_prototypes1d_nclasses_with_data(self):
"""Test ValueError raise if provided `nclasses` is not the same
as the one computed from the provided `data`.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=1,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
data=[[[1.0], [2.0]], [1, 2]],
)
def test_prototypes1d_init_with_ppc(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototypes_per_class=2,
prototype_initializer="zeros")
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_init_with_pdist(self):
p1 = prototypes.Prototypes1D(
data=[self.x, self.y],
prototype_distribution=[6, 9],
prototype_initializer="zeros",
)
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(15, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_func_initializer(self):
def my_initializer(*args, **kwargs):
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
p1 = prototypes.Prototypes1D(
input_dim=99,
nclasses=2,
prototypes_per_class=1,
prototype_initializer=my_initializer,
)
protos = p1.prototypes
actual = protos.detach().numpy()
desired = 99 * torch.ones(2, 99)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_forward(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y])
protos, _ = p1()
actual = protos.detach().numpy()
desired = torch.ones(2, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_dist_validate(self):
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
with self.assertWarns(UserWarning):
_ = p1._validate_prototype_distribution()
def test_prototypes1d_validate_extra_repr_not_empty(self):
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
rep = p1.extra_repr()
self.assertNotEqual(rep, "")
def tearDown(self):
del self.x, self.y, self.gen
_ = torch.seed()
class TestLosses(unittest.TestCase):
def setUp(self):
pass
def test_glvqloss_init(self):
_ = losses.GLVQLoss(0, "swish_beta", beta=20)
def test_glvqloss_forward_1ppc(self):
criterion = losses.GLVQLoss(margin=0,
squashing="sigmoid_beta",
beta=100)
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([0, 1])
targets = torch.ones(100)
outputs = [d, labels]
loss = criterion(outputs, targets)
loss_value = loss.item()
self.assertAlmostEqual(loss_value, 0.0)
def test_glvqloss_forward_2ppc(self):
criterion = losses.GLVQLoss(margin=0,
squashing="sigmoid_beta",
beta=100)
d = torch.stack([
torch.ones(100),
torch.ones(100),
torch.zeros(100),
torch.ones(100)
],
dim=1)
labels = torch.tensor([0, 0, 1, 1])
targets = torch.ones(100)
outputs = [d, labels]
loss = criterion(outputs, targets)
loss_value = loss.item()
self.assertAlmostEqual(loss_value, 0.0)
def tearDown(self):
pass