Compare commits

90 Commits (SHA1):
6ffd14e85c, 40c1021c20, acf3272fd7, c73f8e7a28, bf23d5f7f8, bcde3f6ac8,
d5229b1750, fc4b143fbb, 11cfa79746, d0ae94f2af, 2c908a8361, e4257ec1f1,
aaad2b8626, c0c0044a42, 47d7f5831f, 4f1c879528, 2272c55092, b03c9b1d3c,
0c28eda706, 7bc0bfa3ab, 827958a28a, 8200e1d3d8, 729b20e9ab, ca8ac7a43b,
b724a28a6f, 1e0a8392a2, 2eb7b05653, d8a0b2dfcc, 2a7394b593, b1e64c8b8b,
70cf17607e, b1568a550a, e8e803e8ef, 2c453265fe, 7336d35fee, bc18952c05,
8e8d0b9c2c, 5a7da2b40b, b6d38f442b, 8e8851d962, 27b43b06a7, ff69eb1256,
4ca581909a, 2722d976f5, 946cda00d2, 8227525c82, e61ae73749, 040d1ee9e8,
7f0da894fa, 62726df278, 0ba09db6fe, 87334c11e6, 40ef3aeda2, 94fe4435a8,
c204bc8e1f, 00615ae837, 9f5f0d12dd, 8a291f7bfb, 21e3e3b82d, a6bd6e130a,
fcdfa52892, 73e6fe384e, aff7a385a3, 1e23ba05fa, ee30d4da5b, 14508f0600,
e3f8828da4, 30adbf705c, ee42fd68b1, 736d9a6349, 0055e15bc1, b2e1df7308,
b935e9caf3, 503ef0e05f, dc6248413c, e73b70ceb7, 639198e774, 768d969f89,
aec422c277, 6c14170de6, 36a330aa66, acd4ac6a86, abe64cfe8f, caae95d01d,
088429a16a, b6145223c8, 09256956f3, 0ca90fdcee, be21412f8a, ae6bc47f87
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.2
+current_version = 0.5.1
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
.github/ISSUE_TEMPLATE/bug_report.md (vendored, 8 changes)
@@ -23,9 +23,9 @@ A clear and concise description of what you expected to happen.
 If applicable, add screenshots to help explain your problem.
 
 **Desktop (please complete the following information):**
-- OS: [e.g. Ubuntu 20.10]
-- Prototorch Version: [e.g. v0.4.0]
-- Python Version: [e.g. 3.9.5]
+- OS: [e.g. Ubuntu 20.10]
+- Prototorch Version: [e.g. v0.4.0]
+- Python Version: [e.g. 3.9.5]
 
 **Additional context**
-Add any other context about the problem here.
+Add any other context about the problem here.
.gitignore (vendored, 3 changes)
@@ -154,4 +154,5 @@ scratch*
 # End of https://www.gitignore.io/api/visualstudiocode
 .vscode/
 
-reports
+reports
+artifacts
.pre-commit-config.yaml (new file, 54 lines)
@@ -0,0 +1,54 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.0.1
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-ast
      - id: check-case-conflict

  - repo: https://github.com/myint/autoflake
    rev: v1.4
    hooks:
      - id: autoflake

  - repo: http://github.com/PyCQA/isort
    rev: 5.8.0
    hooks:
      - id: isort

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v0.902'
    hooks:
      - id: mypy
        files: prototorch
        additional_dependencies: [types-pkg_resources]

  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: 'v0.31.0'  # Use the sha / tag you want to point at
    hooks:
      - id: yapf

  - repo: https://github.com/pre-commit/pygrep-hooks
    rev: v1.9.0  # Use the ref you want to point at
    hooks:
      - id: python-use-type-annotations
      - id: python-no-log-warn
      - id: python-check-blanket-noqa

  - repo: https://github.com/asottile/pyupgrade
    rev: v2.19.4
    hooks:
      - id: pyupgrade

  - repo: https://github.com/jorisroovers/gitlint
    rev: "v0.15.1"
    hooks:
      - id: gitlint
        args: [--contrib=CT1, --ignore=B6, --msg-filename]
@@ -4,7 +4,9 @@ language: python
 python: 3.8
 cache:
   directories:
     - "$HOME/.cache/pip"
+    - "./tests/artifacts"
+    - "$HOME/datasets"
 install:
   - pip install .[all] --progress-bar off
README.md (11 changes)
@@ -48,6 +48,17 @@ pip install -e .[all]
 The documentation is available at <https://www.prototorch.ml/en/latest/>. Should
 that link not work try <https://prototorch.readthedocs.io/en/latest/>.
 
+## Contribution
+
+This repository contains definitions for [git hooks](https://githooks.com).
+[Pre-commit](https://pre-commit.com) is installed as a development dependency with prototorch.
+Please install the hooks by running
+```bash
+pre-commit install
+pre-commit install --hook-type commit-msg
+```
+before creating the first commit.
+
 ## Bibtex
 
 If you would like to cite the package, please use this:
@@ -1,13 +1,16 @@
 # ProtoTorch Releases
 
+## Release 0.5.0
+
+- Breaking: Removed deprecated `prototorch.modules.Prototypes1D`.
+  - Use `prototorch.components.LabeledComponents` instead.
+
 ## Release 0.2.0
 
 ### Includes
 - Fixes in example scripts.
 
 ## Release 0.1.1-dev0
 
 ### Includes
 - Minor bugfixes.
 - 100% line coverage.
@@ -1,13 +1,24 @@
-.. ProtoFlow API Reference
+.. ProtoTorch API Reference
 
-ProtoFlow API Reference
+ProtoTorch API Reference
 ======================================
 
+Datasets
+--------------------------------------
+
+Common Datasets
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: prototorch.datasets
+   :members:
+   :undoc-members:
+
+
+Abstract Datasets
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Abstract Datasets are used to build your own datasets.
+
+.. autoclass:: prototorch.datasets.abstract.NumpyDataset
+   :members:
+
 Functions
 --------------------------------------
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 
 # The full version, including alpha/beta/rc tags
 #
-release = "0.4.2"
+release = "0.5.1"
 
 # -- General configuration ---------------------------------------------------

@@ -46,6 +46,7 @@ extensions = [
     "sphinx.ext.viewcode",
     "sphinx_rtd_theme",
     "sphinxcontrib.katex",
+    'sphinx_autodoc_typehints',
 ]
 
 # katex_prerender = True

@@ -179,6 +180,9 @@ texinfo_documents = [
 intersphinx_mapping = {
     "python": ("https://docs.python.org/", None),
     "numpy": ("https://docs.scipy.org/doc/numpy/", None),
+    "torch": ('https://pytorch.org/docs/stable/', None),
+    "pytorch_lightning":
+    ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
 }
 
 # -- Options for Epub output ----------------------------------------------
@@ -1,120 +0,0 @@
"""ProtoTorch GLVQ example using 2D Iris data."""

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary

from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D

# Prepare and preprocess the data
scaler = StandardScaler()
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
scaler.fit(x_train)
x_train = scaler.transform(x_train)


# Define the GLVQ model
class Model(torch.nn.Module):
    def __init__(self):
        """GLVQ model for training on 2D Iris data."""
        super().__init__()
        self.proto_layer = Prototypes1D(
            input_dim=2,
            prototypes_per_class=3,
            nclasses=3,
            prototype_initializer="stratified_random",
            data=[x_train, y_train],
        )

    def forward(self, x):
        protos = self.proto_layer.prototypes
        plabels = self.proto_layer.prototype_labels
        dis = euclidean_distance(x, protos)
        return dis, plabels


# Build the GLVQ model
model = Model()

# Print summary using torchinfo (might be buggy/incorrect)
print(summary(model))

# Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)

# Training loop
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(70):
    # Compute loss
    dis, plabels = model(x_in)
    loss = criterion([dis, plabels], y_in)
    with torch.no_grad():
        pred = wtac(dis, plabels)
        correct = pred.eq(y_in.view_as(pred)).sum().item()
    acc = 100.0 * correct / len(x_train)
    print(
        f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
    )

    # Take a gradient descent step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model
    protos = model.proto_layer.prototypes.data.numpy()
    if np.isnan(np.sum(protos)):
        print("Stopping training because of `nan` in prototypes.")
        break

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()
    ax.set_title(title)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    cmap = "viridis"
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
    ax.scatter(
        protos[:, 0],
        protos[:, 1],
        c=plabels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions
    x = np.vstack((x_train, protos))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                         np.arange(y_min, y_max, 1 / 50))
    mesh_input = np.c_[xx.ravel(), yy.ravel()]

    torch_input = torch.Tensor(mesh_input)
    d = model(torch_input)[0]
    w_indices = torch.argmin(d, dim=1)
    y_pred = torch.index_select(plabels, 0, w_indices)
    y_pred = y_pred.reshape(xx.shape)

    # Plot voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)

    ax.set_xlim(left=x_min + 0, right=x_max - 0)
    ax.set_ylim(bottom=y_min + 0, top=y_max - 0)

    plt.pause(0.1)
@@ -1,104 +0,0 @@
"""ProtoTorch "siamese" GMLVQ example using Tecator."""

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles

# Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)


class Model(torch.nn.Module):
    def __init__(self, **kwargs):
        """GMLVQ model as a siamese network."""
        super().__init__()
        x, y = train_data.data, train_data.targets
        self.p1 = Prototypes1D(
            input_dim=100,
            prototypes_per_class=2,
            nclasses=2,
            prototype_initializer="stratified_random",
            data=[x, y],
        )
        self.omega = torch.nn.Linear(in_features=100,
                                     out_features=100,
                                     bias=False)
        torch.nn.init.eye_(self.omega.weight)

    def forward(self, x):
        protos = self.p1.prototypes
        plabels = self.p1.prototype_labels

        # Process `x` and `protos` through `omega`
        x_map = self.omega(x)
        protos_map = self.omega(protos)

        # Compute distances and output
        dis = sed(x_map, protos_map)
        return dis, plabels


# Build the GLVQ model
model = Model()

# Print a summary of the model
print(model)

# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.001_0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=75, gamma=0.1)
criterion = GLVQLoss(squashing="identity", beta=10)

# Training loop
for epoch in range(150):
    epoch_loss = 0.0  # zero-out epoch loss
    optimizer.zero_grad()  # zero-out gradients
    for xb, yb in train_loader:
        # Compute loss
        distances, plabels = model(xb)
        loss = criterion([distances, plabels], yb)
        epoch_loss += loss.item()
        # Backprop
        loss.backward()
    # Take a gradient descent step
    optimizer.step()
    scheduler.step()

    lr = optimizer.param_groups[0]["lr"]
    print(f"Epoch: {epoch + 1:03d} Loss: {epoch_loss:06.02f} lr: {lr:07.06f}")

# Get the omega matrix from the model
omega = model.omega.weight.data.numpy().T

# Visualize the lambda matrix
title = "Lambda Matrix Visualization"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
im = ax.imshow(omega.dot(omega.T), cmap="viridis")
plt.show()

# Get the prototypes from the model
protos = model.p1.prototypes.data.numpy()
plabels = model.p1.prototype_labels

# Visualize the prototypes
title = "Tecator Prototypes"
fig = plt.figure(title)
ax = fig.gca()
ax.set_title(title)
ax.set_xlabel("Spectral frequencies")
ax.set_ylabel("Absorption")
clabels = ["Class 0 - Low fat", "Class 1 - High fat"]
handles, colors = get_legend_handles(clabels, marker="line", zero_indexed=True)
for x, y in zip(protos, plabels):
    ax.plot(x, c=colors[int(y)])
ax.legend(handles, clabels)
plt.show()
@@ -1,184 +0,0 @@
"""
ProtoTorch GTLVQ example using MNIST data.
GTLVQ is placed as a classification model on
top of a CNN, which serves as the feature extractor.
Subspaces and prototypes are initialized in a
Siamese fashion.
For more info about GTLVQ see:
DOI:10.1109/IJCNN.2016.7727534
"""

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ

# Parameters and options
n_epochs = 50
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.1
momentum = 0.5
log_interval = 10
cuda = "cuda:1"
random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else "cpu")

# Configure reproducibility
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./files/",
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    ),
    batch_size=batch_size_train,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./files/",
        train=False,
        download=True,
        transform=torchvision.transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ]),
    ),
    batch_size=batch_size_test,
    shuffle=True,
)


# Define the GLVQ model plus appropriate feature extractor
class CNNGTLVQ(torch.nn.Module):
    def __init__(
        self,
        num_classes,
        subspace_data,
        prototype_data,
        tangent_projection_type="local",
        prototypes_per_class=2,
        bottleneck_dim=128,
    ):
        super(CNNGTLVQ, self).__init__()

        # Feature Extractor - Simple CNN
        self.fe = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            nn.Flatten(),
            nn.Linear(9216, bottleneck_dim),
            nn.Dropout(0.5),
            nn.LeakyReLU(),
            nn.LayerNorm(bottleneck_dim),
        )

        # Forward pass of subspace and prototype initialization data through feature extractor
        subspace_data = self.fe(subspace_data)
        prototype_data[0] = self.fe(prototype_data[0])

        # Initialization of GTLVQ
        self.gtlvq = GTLVQ(
            num_classes,
            subspace_data,
            prototype_data,
            tangent_projection_type=tangent_projection_type,
            feature_dim=bottleneck_dim,
            prototypes_per_class=prototypes_per_class,
        )

    def forward(self, x):
        # Feature Extraction
        x = self.fe(x)

        # GTLVQ Forward pass
        dis = self.gtlvq(x)
        return dis


# Get init data
subspace_data = torch.cat(
    [next(iter(train_loader))[0],
     next(iter(test_loader))[0]])
prototype_data = next(iter(train_loader))

# Build the CNN GTLVQ model
model = CNNGTLVQ(
    10,
    subspace_data,
    prototype_data,
    tangent_projection_type="local",
    bottleneck_dim=128,
).to(device)

# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(
    [{
        "params": model.fe.parameters()
    }, {
        "params": model.gtlvq.parameters()
    }],
    lr=learning_rate,
)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

# Training loop
for epoch in range(n_epochs):
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        model.train()
        x_train, y_train = x_train.to(device), y_train.to(device)
        optimizer.zero_grad()

        distances = model(x_train)
        plabels = model.gtlvq.cls.prototype_labels.to(device)

        # Compute loss.
        loss = criterion([distances, plabels], y_train)
        loss.backward()
        optimizer.step()

        # GTLVQ uses projected SGD, which means orthogonalizing the subspaces after every gradient update.
        model.gtlvq.orthogonalize_subspace()

        if batch_idx % log_interval == 0:
            acc = calculate_prototype_accuracy(distances, y_train, plabels)
            print(
                f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
                Train Acc: {acc.item():02.02f}")

    # Test
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            test_distances = model(x_test)
            test_plabels = model.gtlvq.cls.prototype_labels.to(device)
            i = torch.argmin(test_distances, 1)
            correct += torch.sum(y_test == test_plabels[i])
            total += y_test.size(0)
        print("Accuracy of the network on the test images: %d %%" %
              (torch.true_divide(correct, total) * 100))

# Save the model
PATH = "./glvq_mnist_model.pth"
torch.save(model.state_dict(), PATH)
@@ -1,110 +0,0 @@
"""ProtoTorch LGMLVQ example using 2D Iris data."""

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.functions.init import eye_
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D

# Prepare training data
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]


# Define the model
class Model(torch.nn.Module):
    def __init__(self):
        """Local-GMLVQ model."""
        super().__init__()
        self.p1 = Prototypes1D(
            input_dim=2,
            prototype_distribution=[1, 2, 2],
            prototype_initializer="stratified_random",
            data=[x_train, y_train],
        )
        omegas = torch.zeros(5, 2, 2)
        self.omegas = torch.nn.Parameter(omegas)
        eye_(self.omegas)

    def forward(self, x):
        protos = self.p1.prototypes
        plabels = self.p1.prototype_labels
        omegas = self.omegas
        dis = lomega_distance(x, protos, omegas)
        return dis, plabels


# Build the model
model = Model()

# Optimize using Adam optimizer from `torch.optim`
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)

x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)

# Training loop
title = "Prototype Visualization"
fig = plt.figure(title)
for epoch in range(100):
    # Compute loss
    dis, plabels = model(x_in)
    loss = criterion([dis, plabels], y_in)
    y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1)
    acc = accuracy_score(y_train, y_pred)
    log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
    log_string += f"Acc: {acc * 100:05.02f}%"
    print(log_string)

    # Take a gradient descent step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Get the prototypes from the model
    protos = model.p1.prototypes.data.numpy()

    # Visualize the data and the prototypes
    ax = fig.gca()
    ax.cla()
    ax.set_title(title)
    ax.set_xlabel("Data dimension 1")
    ax.set_ylabel("Data dimension 2")
    cmap = "viridis"
    ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
    ax.scatter(
        protos[:, 0],
        protos[:, 1],
        c=plabels,
        cmap=cmap,
        edgecolor="k",
        marker="D",
        s=50,
    )

    # Paint decision regions
    x = np.vstack((x_train, protos))
    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
                         np.arange(y_min, y_max, 1 / 50))
    mesh_input = np.c_[xx.ravel(), yy.ravel()]

    d, plabels = model(torch.Tensor(mesh_input))
    y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1)
    y_pred = y_pred.reshape(xx.shape)

    # Plot voronoi regions
    ax.contourf(xx, yy, y_pred, cmap=cmap, alpha=0.35)

    ax.set_xlim(left=x_min + 0, right=x_max - 0)
    ax.set_ylim(bottom=y_min + 0, top=y_max - 0)

    plt.pause(0.1)
@@ -1,22 +1,26 @@
 """ProtoTorch package."""
 
+import pkgutil
+from typing import List
+
+import pkg_resources
+
+from . import components, datasets, functions, modules, utils
+from .datasets import *
+
 # Core Setup
-__version__ = "0.4.2"
+__version__ = "0.5.1"
 
 __all_core__ = [
     "datasets",
     "functions",
     "modules",
     "components",
     "utils",
 ]
 
-from .datasets import *
-
 # Plugin Loader
-import pkgutil
-
-import pkg_resources
-
-__path__ = pkgutil.extend_path(__path__, __name__)
+__path__: List[str] = pkgutil.extend_path(__path__, __name__)
 
 
 def discover_plugins():
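The hunk is cut off before the body of `discover_plugins`. For orientation only, a minimal sketch of how such a loader is commonly written with `pkg_resources`; the entry-point group name used here is an assumption, not taken from the diff:

```python
import pkg_resources


def discover_plugins(group="prototorch.plugins"):  # group name is assumed
    """Collect the objects advertised by installed packages under `group`."""
    return {
        entry_point.name: entry_point.load()
        for entry_point in pkg_resources.iter_entry_points(group)
    }
```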
@@ -1,59 +1,148 @@
 """ProtoTorch components modules."""
 
 import warnings
 from typing import Tuple
 
 import torch
+from torch.nn.parameter import Parameter
 
 from prototorch.components.initializers import (ClassAwareInitializer,
                                                 ComponentsInitializer,
                                                 EqualLabelsInitializer,
                                                 UnequalLabelsInitializer,
                                                 ZeroReasoningsInitializer)
-from prototorch.functions.initializers import get_initializer
-from torch.nn.parameter import Parameter
+
+from .initializers import parse_data_arg
+
+
+def get_labels_object(distribution):
+    if isinstance(distribution, dict):
+        if "num_classes" in distribution.keys():
+            labels = EqualLabelsInitializer(
+                distribution["num_classes"],
+                distribution["prototypes_per_class"])
+        else:
+            clabels = list(distribution.keys())
+            dist = list(distribution.values())
+            labels = UnequalLabelsInitializer(dist, clabels)
+    elif isinstance(distribution, tuple):
+        num_classes, prototypes_per_class = distribution
+        labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
+    elif isinstance(distribution, list):
+        labels = UnequalLabelsInitializer(distribution)
+    else:
+        msg = f"`distribution` not understood. " \
+            f"You have provided: {distribution=}."
+        raise ValueError(msg)
+    return labels
+
+
+def _precheck_initializer(initializer):
+    if not isinstance(initializer, ComponentsInitializer):
+        emsg = f"`initializer` has to be some subtype of " \
+            f"{ComponentsInitializer}. " \
+            f"You have provided: {initializer=} instead."
+        raise TypeError(emsg)
+
+
+class LinearMapping(torch.nn.Module):
+    """LinearMapping is a learnable mapping matrix."""
+    def __init__(self,
+                 mapping_shape=None,
+                 initializer=None,
+                 *,
+                 initialized_linearmapping=None):
+        super().__init__()
+
+        # Ignore all initialization settings if initialized_linearmapping is given.
+        if initialized_linearmapping is not None:
+            self._register_mapping(initialized_linearmapping)
+            if mapping_shape is not None or initializer is not None:
+                wmsg = "Arguments ignored while initializing LinearMapping"
+                warnings.warn(wmsg)
+        else:
+            self._initialize_mapping(mapping_shape, initializer)
+
+    @property
+    def mapping_shape(self):
+        return self._omega.shape
+
+    def _register_mapping(self, components):
+        self.register_parameter("_omega", Parameter(components))
+
+    def _initialize_mapping(self, mapping_shape, initializer):
+        _precheck_initializer(initializer)
+        _mapping = initializer.generate(mapping_shape)
+        self._register_mapping(_mapping)
+
+    @property
+    def mapping(self):
+        """Tensor containing the component tensors."""
+        return self._omega.detach()
+
+    def forward(self):
+        return self._omega
 
 
 class Components(torch.nn.Module):
     """Components is a set of learnable Tensors."""
     def __init__(self,
-                 number_of_components=None,
+                 num_components=None,
                  initializer=None,
                  *,
-                 initialized_components=None,
-                 dtype=torch.float32):
+                 initialized_components=None):
         super().__init__()
 
         # Ignore all initialization settings if initialized_components is given.
         if initialized_components is not None:
-            self._components = Parameter(initialized_components)
-            if number_of_components is not None or initializer is not None:
+            self._register_components(initialized_components)
+            if num_components is not None or initializer is not None:
                 wmsg = "Arguments ignored while initializing Components"
                 warnings.warn(wmsg)
         else:
-            self._initialize_components(number_of_components, initializer)
+            self._initialize_components(num_components, initializer)
 
-    def _precheck_initializer(self, initializer):
-        if not isinstance(initializer, ComponentsInitializer):
-            emsg = f"`initializer` has to be some subtype of " \
-                f"{ComponentsInitializer}. " \
-                f"You have provided: {initializer=} instead."
-            raise TypeError(emsg)
+    @property
+    def num_components(self):
+        return len(self._components)
 
-    def _initialize_components(self, number_of_components, initializer):
-        self._precheck_initializer(initializer)
-        self._components = Parameter(
-            initializer.generate(number_of_components))
+    def _register_components(self, components):
+        self.register_parameter("_components", Parameter(components))
+
+    def _initialize_components(self, num_components, initializer):
+        _precheck_initializer(initializer)
+        _components = initializer.generate(num_components)
+        self._register_components(_components)
+
+    def add_components(self,
+                       num=1,
+                       initializer=None,
+                       *,
+                       initialized_components=None):
+        if initialized_components is not None:
+            _components = torch.cat([self._components, initialized_components])
+        else:
+            _precheck_initializer(initializer)
+            _new = initializer.generate(num)
+            _components = torch.cat([self._components, _new])
+        self._register_components(_components)
+
+    def remove_components(self, indices=None):
+        mask = torch.ones(self.num_components, dtype=torch.bool)
+        mask[indices] = False
+        _components = self._components[mask]
+        self._register_components(_components)
+        return mask
 
     @property
     def components(self):
         """Tensor containing the component tensors."""
-        return self._components.detach().cpu()
+        return self._components.detach()
 
     def forward(self):
         return self._components
 
     def extra_repr(self):
-        return f"components.shape: {tuple(self._components.shape)}"
+        return f"(components): (shape: {tuple(self._components.shape)})"
 
 
 class LabeledComponents(Components):

@@ -67,43 +156,72 @@ class LabeledComponents(Components):
                  *,
                  initialized_components=None):
         if initialized_components is not None:
-            super().__init__(initialized_components=initialized_components[0])
-            self._labels = initialized_components[1]
+            components, component_labels = parse_data_arg(
+                initialized_components)
+            super().__init__(initialized_components=components)
+            self._register_labels(component_labels)
         else:
-            self._initialize_labels(distribution)
-            super().__init__(number_of_components=len(self._labels),
-                             initializer=initializer)
+            labels = get_labels_object(distribution)
+            self.initial_distribution = labels.distribution
+            _labels = labels.generate()
+            super().__init__(len(_labels), initializer=initializer)
+            self._register_labels(_labels)
 
-    def _initialize_components(self, number_of_components, initializer):
+    def _register_labels(self, labels):
+        self.register_buffer("_labels", labels)
+
+    @property
+    def distribution(self):
+        clabels, counts = torch.unique(self._labels,
+                                       sorted=True,
+                                       return_counts=True)
+        return dict(zip(clabels.tolist(), counts.tolist()))
+
+    def _initialize_components(self, num_components, initializer):
         if isinstance(initializer, ClassAwareInitializer):
-            self._precheck_initializer(initializer)
-            self._components = Parameter(
-                initializer.generate(number_of_components, self.distribution))
+            _precheck_initializer(initializer)
+            _components = initializer.generate(num_components,
                                               self.initial_distribution)
+            self._register_components(_components)
         else:
-            super()._initialize_components(self, number_of_components,
-                                           initializer)
+            super()._initialize_components(num_components, initializer)
 
-    def _initialize_labels(self, distribution):
-        if type(distribution) == tuple:
-            num_classes, prototypes_per_class = distribution
-            labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
-        elif type(distribution) == list:
-            labels = UnequalLabelsInitializer(distribution)
-
-        self.distribution = labels.distribution
-        self._labels = labels.generate()
+    def add_components(self, distribution, initializer):
+        _precheck_initializer(initializer)
+
+        # Labels
+        labels = get_labels_object(distribution)
+        new_labels = labels.generate()
+        _labels = torch.cat([self._labels, new_labels])
+        self._register_labels(_labels)
+
+        # Components
+        if isinstance(initializer, ClassAwareInitializer):
+            _new = initializer.generate(len(new_labels), distribution)
+        else:
+            _new = initializer.generate(len(new_labels))
+        _components = torch.cat([self._components, _new])
+        self._register_components(_components)
+
+    def remove_components(self, indices=None):
+        # Components
+        mask = super().remove_components(indices)
+
+        # Labels
+        _labels = self._labels[mask]
+        self._register_labels(_labels)
 
     @property
     def component_labels(self):
         """Tensor containing the component tensors."""
-        return self._labels.detach().cpu()
+        return self._labels.detach()
 
     def forward(self):
         return super().forward(), self._labels
 
 
 class ReasoningComponents(Components):
-    """ReasoningComponents generate a set of components and a set of reasoning matrices.
+    r"""ReasoningComponents generate a set of components and a set of reasoning matrices.
 
     Every Component has a reasoning matrix assigned.
 

@@ -123,20 +241,21 @@ class ReasoningComponents(Components):
                  *,
                  initialized_components=None):
         if initialized_components is not None:
-            super().__init__(initialized_components=initialized_components[0])
-            self._reasonings = initialized_components[1]
+            components, reasonings = initialized_components
+
+            super().__init__(initialized_components=components)
+            self.register_parameter("_reasonings", reasonings)
         else:
             self._initialize_reasonings(reasonings)
-            super().__init__(number_of_components=len(self._reasonings),
-                             initializer=initializer)
+            super().__init__(len(self._reasonings), initializer=initializer)
 
     def _initialize_reasonings(self, reasonings):
-        if type(reasonings) == tuple:
-            num_classes, number_of_components = reasonings
-            reasonings = ZeroReasoningsInitializer(num_classes,
-                                                   number_of_components)
+        if isinstance(reasonings, tuple):
+            num_classes, num_components = reasonings
+            reasonings = ZeroReasoningsInitializer(num_classes, num_components)
 
-        self._reasonings = reasonings.generate()
+        _reasonings = reasonings.generate()
+        self.register_parameter("_reasonings", _reasonings)
 
     @property
     def reasonings(self):

@@ -145,7 +264,7 @@ class ReasoningComponents(Components):
         Dimension NxCx2
 
         """
-        return self._reasonings.detach().cpu()
+        return self._reasonings.detach()
 
     def forward(self):
         return super().forward(), self._reasonings
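A minimal usage sketch for the refactored components API, based only on the signatures visible in this diff; the Iris wrapper and the stratified initializer are used purely for illustration:

```python
from prototorch.components import LabeledComponents
from prototorch.components.initializers import StratifiedMeanInitializer
from prototorch.datasets import Iris

ds = Iris(dims=[0, 2])

# Two prototypes for each of the three classes, initialized on class means.
components = LabeledComponents(
    distribution=(3, 2),
    initializer=StratifiedMeanInitializer(ds),
)
protos, plabels = components()   # forward() returns components and labels
print(components.distribution)   # e.g. {0: 2, 1: 2, 2: 2}
```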
@@ -7,21 +7,36 @@ import torch
 from torch.utils.data import DataLoader, Dataset
 
 
-def parse_init_arg(arg):
-    if isinstance(arg, Dataset):
-        data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
-        # data = data.view(len(arg), -1)  # flatten
+def parse_data_arg(data_arg):
+    if isinstance(data_arg, Dataset):
+        data_arg = DataLoader(data_arg, batch_size=len(data_arg))
+
+    if isinstance(data_arg, DataLoader):
+        data = torch.tensor([])
+        targets = torch.tensor([])
+        for x, y in data_arg:
+            data = torch.cat([data, x])
+            targets = torch.cat([targets, y])
     else:
-        data, labels = arg
+        data, targets = data_arg
         if not isinstance(data, torch.Tensor):
             wmsg = f"Converting data to {torch.Tensor}."
             warnings.warn(wmsg)
             data = torch.Tensor(data)
-        if not isinstance(labels, torch.Tensor):
-            wmsg = f"Converting labels to {torch.Tensor}."
+        if not isinstance(targets, torch.Tensor):
+            wmsg = f"Converting targets to {torch.Tensor}."
             warnings.warn(wmsg)
-            labels = torch.Tensor(labels)
-    return data, labels
+            targets = torch.Tensor(targets)
+    return data, targets
+
+
+def get_subinitializers(data, targets, clabels, subinit_type):
+    initializers = dict()
+    for clabel in clabels:
+        class_data = data[targets == clabel]
+        class_initializer = subinit_type(class_data)
+        initializers[clabel] = (class_initializer)
+    return initializers
 
 
 # Components

@@ -31,18 +46,22 @@ class ComponentsInitializer(object):
 
 
 class DimensionAwareInitializer(ComponentsInitializer):
-    def __init__(self, c_dims):
+    def __init__(self, dims):
         super().__init__()
-        if isinstance(c_dims, Iterable):
-            self.components_dims = tuple(c_dims)
+        if isinstance(dims, Iterable):
+            self.components_dims = tuple(dims)
         else:
-            self.components_dims = (c_dims, )
+            self.components_dims = (dims, )
 
 
 class OnesInitializer(DimensionAwareInitializer):
+    def __init__(self, dims, scale=1.0):
+        super().__init__(dims)
+        self.scale = scale
+
     def generate(self, length):
         gen_dims = (length, ) + self.components_dims
-        return torch.ones(gen_dims)
+        return torch.ones(gen_dims) * self.scale
 
 
 class ZerosInitializer(DimensionAwareInitializer):

@@ -52,97 +71,110 @@ class ZerosInitializer(DimensionAwareInitializer):
 
 
 class UniformInitializer(DimensionAwareInitializer):
-    def __init__(self, c_dims, min=0.0, max=1.0):
-        super().__init__(c_dims)
-
-        self.min = min
-        self.max = max
+    def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0):
+        super().__init__(dims)
+        self.minimum = minimum
+        self.maximum = maximum
+        self.scale = scale
 
     def generate(self, length):
         gen_dims = (length, ) + self.components_dims
-        return torch.ones(gen_dims).uniform_(self.min, self.max)
+        return torch.ones(gen_dims).uniform_(self.minimum,
+                                             self.maximum) * self.scale
 
 
-class PositionAwareInitializer(ComponentsInitializer):
-    def __init__(self, positions):
+class DataAwareInitializer(ComponentsInitializer):
+    def __init__(self, data, transform=torch.nn.Identity()):
         super().__init__()
-        self.data = positions
+        self.data = data
+        self.transform = transform
+
+    def __del__(self):
+        del self.data
 
 
-class SelectionInitializer(PositionAwareInitializer):
+class SelectionInitializer(DataAwareInitializer):
     def generate(self, length):
         indices = torch.LongTensor(length).random_(0, len(self.data))
-        return self.data[indices]
+        return self.transform(self.data[indices])
 
 
-class MeanInitializer(PositionAwareInitializer):
+class MeanInitializer(DataAwareInitializer):
     def generate(self, length):
         mean = torch.mean(self.data, dim=0)
         repeat_dim = [length] + [1] * len(mean.shape)
-        return mean.repeat(repeat_dim)
+        return self.transform(mean.repeat(repeat_dim))
 
 
-class ClassAwareInitializer(ComponentsInitializer):
-    def __init__(self, arg):
-        super().__init__()
-        data, labels = parse_init_arg(arg)
-        self.data = data
-        self.labels = labels
-
-        self.clabels = torch.unique(self.labels)
+class ClassAwareInitializer(DataAwareInitializer):
+    def __init__(self, data, transform=torch.nn.Identity()):
+        data, targets = parse_data_arg(data)
+        super().__init__(data, transform)
+        self.targets = targets
+        self.clabels = torch.unique(self.targets).int().tolist()
         self.num_classes = len(self.clabels)
 
     def _get_samples_from_initializer(self, length, dist):
         if not dist:
             per_class = length // self.num_classes
-            dist = self.num_classes * [per_class]
-        samples_list = [
-            init.generate(n) for init, n in zip(self.initializers, dist)
-        ]
-        return torch.vstack(samples_list)
+            dist = dict(zip(self.clabels, self.num_classes * [per_class]))
+        if isinstance(dist, list):
+            dist = dict(zip(self.clabels, dist))
+        samples = [self.initializers[k].generate(n) for k, n in dist.items()]
+        out = torch.vstack(samples)
+        with torch.no_grad():
+            out = self.transform(out)
+        return out
+
+    def __del__(self):
+        del self.data
+        del self.targets
 
 
 class StratifiedMeanInitializer(ClassAwareInitializer):
-    def __init__(self, arg):
-        super().__init__(arg)
-
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = MeanInitializer(class_data)
-            self.initializers.append(class_initializer)
+    def __init__(self, data, **kwargs):
+        super().__init__(data, **kwargs)
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels, MeanInitializer)
 
-    def generate(self, length, dist=[]):
+    def generate(self, length, dist):
         samples = self._get_samples_from_initializer(length, dist)
         return samples
 
 
 class StratifiedSelectionInitializer(ClassAwareInitializer):
-    def __init__(self, arg, *, noise=None):
-        super().__init__(arg)
+    def __init__(self, data, noise=None, **kwargs):
+        super().__init__(data, **kwargs)
         self.noise = noise
+        self.initializers = get_subinitializers(self.data, self.targets,
+                                                self.clabels,
+                                                SelectionInitializer)
 
-        self.initializers = []
-        for clabel in self.clabels:
-            class_data = self.data[self.labels == clabel]
-            class_initializer = SelectionInitializer(class_data)
-            self.initializers.append(class_initializer)
+    def add_noise_v1(self, x):
+        return x + self.noise
 
-    def add_noise(self, x):
+    def add_noise_v2(self, x):
         """Shifts some dimensions of the data randomly."""
         n1 = torch.rand_like(x)
         n2 = torch.rand_like(x)
         mask = torch.bernoulli(n1) - torch.bernoulli(n2)
         return x + (self.noise * mask)
 
-    def generate(self, length, dist=[]):
+    def generate(self, length, dist):
         samples = self._get_samples_from_initializer(length, dist)
         if self.noise is not None:
-            # samples = self.add_noise(samples)
-            samples = samples + self.noise
+            samples = self.add_noise_v1(samples)
         return samples
 
 
+# Omega matrix
+class PcaInitializer(DataAwareInitializer):
+    def generate(self, shape):
+        (input_dim, latent_dim) = shape
+        (_, eigVal, eigVec) = torch.pca_lowrank(self.data, q=latent_dim)
+        return eigVec
 
 
 # Labels
 class LabelsInitializer:
     def generate(self):

@@ -150,17 +182,18 @@ class LabelsInitializer:
 
 
 class UnequalLabelsInitializer(LabelsInitializer):
-    def __init__(self, dist):
+    def __init__(self, dist, clabels=None):
         self.dist = dist
+        self.clabels = clabels or range(len(self.dist))
 
     @property
     def distribution(self):
         return self.dist
 
     def generate(self):
-        clabels = range(len(self.dist))
-        labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
-        return torch.tensor(labels)
+        targets = list(
+            chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
+        return torch.LongTensor(targets)
 
 
 class EqualLabelsInitializer(LabelsInitializer):

@@ -195,3 +228,6 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
 SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
 SMI = StratifiedMeanInitializer
 Random = RandomInitializer = UniformInitializer
+Zeros = ZerosInitializer
+Ones = OnesInitializer
+PCA = PcaInitializer
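To see how `parse_data_arg` and the stratified initializers fit together, a short sketch; the shapes in the comments follow from the code above:

```python
import torch
from torch.utils.data import TensorDataset

from prototorch.components.initializers import (StratifiedMeanInitializer,
                                                parse_data_arg)

x = torch.randn(100, 4)
y = torch.cat([torch.zeros(50), torch.ones(50)])  # two classes

# Datasets, DataLoaders, and (data, targets) pairs are all accepted.
data, targets = parse_data_arg(TensorDataset(x, y))

init = StratifiedMeanInitializer(TensorDataset(x, y))
protos = init.generate(4, [2, 2])  # two prototypes per class -> shape (4, 4)
```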
@@ -1,11 +1,6 @@
 """ProtoTorch datasets."""
 
 from .abstract import NumpyDataset
+from .sklearn import Blobs, Circles, Iris, Moons, Random
 from .spiral import Spiral
 from .tecator import Tecator
-
-__all__ = [
-    "NumpyDataset",
-    "Spiral",
-    "Tecator",
-]
@@ -14,8 +14,10 @@ import torch
 
 class NumpyDataset(torch.utils.data.TensorDataset):
     """Create a PyTorch TensorDataset from NumPy arrays."""
-    def __init__(self, *arrays):
-        tensors = [torch.Tensor(arr) for arr in arrays]
+    def __init__(self, data, targets):
+        self.data = torch.Tensor(data)
+        self.targets = torch.LongTensor(targets)
+        tensors = [self.data, self.targets]
         super().__init__(*tensors)
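With the new signature, data and targets are passed as two separate arrays; a quick sketch:

```python
import numpy as np
from torch.utils.data import DataLoader

from prototorch.datasets import NumpyDataset

x = np.random.randn(150, 4)
y = np.random.randint(0, 3, size=150)

ds = NumpyDataset(x, y)  # data becomes float32, targets become int64
x_batch, y_batch = next(iter(DataLoader(ds, batch_size=32)))
```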
prototorch/datasets/sklearn.py (new file, 137 lines)
@@ -0,0 +1,137 @@
"""Thin wrappers for a few scikit-learn datasets.

URL:
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets

"""

import warnings
from typing import Sequence, Union

from sklearn.datasets import (load_iris, make_blobs, make_circles,
                              make_classification, make_moons)

from prototorch.datasets.abstract import NumpyDataset


class Iris(NumpyDataset):
    """Iris Dataset by Ronald Fisher introduced in 1936.

    The dataset contains four measurements from flowers of three species of iris.

    .. list-table:: Iris
       :header-rows: 1

       * - dimensions
         - classes
         - training size
         - validation size
         - test size
       * - 4
         - 3
         - 150
         - 0
         - 0

    :param dims: select a subset of dimensions
    """
    def __init__(self, dims: Sequence[int] = None):
        x, y = load_iris(return_X_y=True)
        if dims:
            x = x[:, dims]
        super().__init__(x, y)


class Blobs(NumpyDataset):
    """Generate isotropic Gaussian blobs for clustering.

    Read more at
    https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.

    """
    def __init__(self,
                 num_samples: int = 300,
                 num_features: int = 2,
                 seed: Union[None, int] = 0):
        x, y = make_blobs(num_samples,
                          num_features,
                          centers=None,
                          random_state=seed,
                          shuffle=False)
        super().__init__(x, y)


class Random(NumpyDataset):
    """Generate a random n-class classification problem.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html.

    Note: n_classes * n_clusters_per_class <= 2**n_informative must be satisfied.
    """
    def __init__(self,
                 num_samples: int = 300,
                 num_features: int = 2,
                 num_classes: int = 2,
                 num_clusters: int = 2,
                 num_informative: Union[None, int] = None,
                 separation: float = 1.0,
                 seed: Union[None, int] = 0):
        if not num_informative:
            import math
            num_informative = math.ceil(math.log2(num_classes * num_clusters))
        if num_features < num_informative:
            warnings.warn("Generating more features than requested.")
            num_features = num_informative
        x, y = make_classification(num_samples,
                                   num_features,
                                   n_informative=num_informative,
                                   n_redundant=0,
                                   n_classes=num_classes,
                                   n_clusters_per_class=num_clusters,
                                   class_sep=separation,
                                   random_state=seed,
                                   shuffle=False)
        super().__init__(x, y)


class Circles(NumpyDataset):
    """Make a large circle containing a smaller circle in 2D.

    A simple toy dataset to visualize clustering and classification algorithms.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html

    """
    def __init__(self,
                 num_samples: int = 300,
                 noise: float = 0.3,
                 factor: float = 0.8,
                 seed: Union[None, int] = 0):
        x, y = make_circles(num_samples,
                            noise=noise,
                            factor=factor,
                            random_state=seed,
                            shuffle=False)
        super().__init__(x, y)


class Moons(NumpyDataset):
    """Make two interleaving half circles.

    A simple toy dataset to visualize clustering and classification algorithms.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html

    """
    def __init__(self,
                 num_samples: int = 300,
                 noise: float = 0.3,
                 seed: Union[None, int] = 0):
        x, y = make_moons(num_samples,
                          noise=noise,
                          random_state=seed,
                          shuffle=False)
        super().__init__(x, y)
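Usage mirrors the constructors above; for instance (the parameter values are illustrative):

```python
from torch.utils.data import DataLoader

from prototorch.datasets import Blobs, Moons

train_ds = Blobs(num_samples=500, num_features=2, seed=42)
val_ds = Moons(num_samples=200, noise=0.2, seed=42)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
x_batch, y_batch = next(iter(train_loader))  # float data, long targets
```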
@@ -4,18 +4,22 @@ import numpy as np
 import torch
 
 
-def make_spiral(n_samples=500, noise=0.3):
+def make_spiral(num_samples=500, noise=0.3):
+    """Generates the Spiral Dataset.
+
+    For use in ProtoTorch use `prototorch.datasets.Spiral` instead.
+    """
     def get_samples(n, delta_t):
         points = []
         for i in range(n):
-            r = i / n_samples * 5
+            r = i / num_samples * 5
             t = 1.75 * i / n * 2 * np.pi + delta_t
             x = r * np.sin(t) + np.random.rand(1) * noise
             y = r * np.cos(t) + np.random.rand(1) * noise
             points.append([x, y])
         return points
 
-    n = n_samples // 2
+    n = num_samples // 2
     positive = get_samples(n=n, delta_t=0)
     negative = get_samples(n=n, delta_t=np.pi)
     x = np.concatenate(

@@ -27,7 +31,27 @@ def make_spiral(num_samples=500, noise=0.3):
 
 
 class Spiral(torch.utils.data.TensorDataset):
-    """Spiral dataset for binary classification."""
-    def __init__(self, n_samples=500, noise=0.3):
-        x, y = make_spiral(n_samples, noise)
+    """Spiral dataset for binary classification.
+
+    This dataset consists of two spirals of two different classes.
+
+    .. list-table:: Spiral
+       :header-rows: 1
+
+       * - dimensions
+         - classes
+         - training size
+         - validation size
+         - test size
+       * - 2
+         - 2
+         - num_samples
+         - 0
+         - 0
+
+    :param num_samples: number of random samples
+    :param noise: noise added to the spirals
+    """
+    def __init__(self, num_samples: int = 500, noise: float = 0.3):
+        x, y = make_spiral(num_samples, noise)
         super().__init__(torch.Tensor(x), torch.LongTensor(y))
@@ -47,8 +47,23 @@ from prototorch.datasets.abstract import ProtoDataset
 
 class Tecator(ProtoDataset):
     """
-    `Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
-    for classification.
+    `Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.
+
+    The dataset contains wavelength measurements of meat.
+
+    .. list-table:: Tecator
+       :header-rows: 1
+
+       * - dimensions
+         - classes
+         - training size
+         - validation size
+         - test size
+       * - 100
+         - 2
+         - 129
+         - 43
+         - 43
     """
 
     _resources = [

@@ -87,12 +102,12 @@ class Tecator(ProtoDataset):
         x_train, y_train = f["x_train"], f["y_train"]
         x_test, y_test = f["x_test"], f["y_test"]
         training_set = [
-            torch.tensor(x_train, dtype=torch.float32),
-            torch.tensor(y_train),
+            torch.Tensor(x_train),
+            torch.LongTensor(y_train),
         ]
         test_set = [
-            torch.tensor(x_test, dtype=torch.float32),
-            torch.tensor(y_test),
+            torch.Tensor(x_test),
+            torch.LongTensor(y_test),
         ]
 
         with open(os.path.join(self.processed_folder, self.training_file),
@@ -2,11 +2,4 @@
 
 from .activations import identity, sigmoid_beta, swish_beta
 from .competitions import knnc, wtac
-
-__all__ = [
-    "identity",
-    "sigmoid_beta",
-    "swish_beta",
-    "knnc",
-    "wtac",
-]
+from .pooling import *
@@ -5,17 +5,14 @@ import torch
 ACTIVATIONS = dict()
 
 
-# def register_activation(scriptf):
-#     ACTIVATIONS[scriptf.name] = scriptf
-#     return scriptf
-def register_activation(function):
+def register_activation(fn):
     """Add the activation function to the registry."""
-    ACTIVATIONS[function.__name__] = function
-    return function
+    name = fn.__name__
+    ACTIVATIONS[name] = fn
+    return fn
 
 
 @register_activation
-# @torch.jit.script
 def identity(x, beta=0.0):
     """Identity activation function.
 

@@ -29,7 +26,6 @@ def identity(x, beta=0.0):
 
 
 @register_activation
-# @torch.jit.script
 def sigmoid_beta(x, beta=10.0):
     r"""Sigmoid activation function with scaling.
 

@@ -44,7 +40,6 @@ def sigmoid_beta(x, beta=10.0):
 
 
 @register_activation
-# @torch.jit.script
 def swish_beta(x, beta=10.0):
     r"""Swish activation function with scaling.
@@ -3,43 +3,26 @@

import torch


# @torch.jit.script
def stratified_min(distances, labels):
    clabels = torch.unique(labels, dim=0)
    nclasses = clabels.size()[0]
    if distances.size()[1] == nclasses:
        # skip if only one prototype per class
        return distances
    batch_size = distances.size()[0]
    winning_distances = torch.zeros(nclasses, batch_size)
    inf = torch.full_like(distances.T, fill_value=float("inf"))
    # distances_to_wpluses = torch.where(matcher, distances, inf)
    for i, cl in enumerate(clabels):
        # cdists = distances.T[labels == cl]
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
        if labels.ndim == 2:
            # if the labels are one-hot vectors
            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
        cdists = torch.where(matcher, distances.T, inf).T
        winning_distances[i] = torch.min(cdists, dim=1,
                                         keepdim=True).values.squeeze()
    if labels.ndim == 2:
        # Transpose to return with `batch_size` first and
        # reverse the columns to fix the ordering of the classes
        return torch.flip(winning_distances.T, dims=(1, ))
def wtac(distances: torch.Tensor,
         labels: torch.LongTensor) -> (torch.LongTensor):
    """Winner-Takes-All-Competition.

    return winning_distances.T  # return with `batch_size` first
    Returns the labels corresponding to the winners.


# @torch.jit.script
def wtac(distances, labels):
    """
    winning_indices = torch.min(distances, dim=1).indices
    winning_labels = labels[winning_indices].squeeze()
    return winning_labels


# @torch.jit.script
def knnc(distances, labels, k):
    winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
    winning_labels = labels[winning_indices].squeeze()
def knnc(distances: torch.Tensor,
         labels: torch.LongTensor,
         k: int = 1) -> (torch.LongTensor):
    """K-Nearest-Neighbors-Competition.

    Returns the labels corresponding to the winners.

    """
    winning_indices = torch.topk(-distances, k=k, dim=1).indices
    winning_labels = torch.mode(labels[winning_indices], dim=1).values
    return winning_labels
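A short usage sketch of the new signatures (not part of the diff; the input values are borrowed from `test_knnc_k1` in the test suite further below):

import torch
from prototorch.functions.competitions import knnc, wtac

d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
labels = torch.tensor([0, 1, 2, 3])
print(wtac(d, labels))       # tensor([2, 0]): label of the closest prototype
print(knnc(d, labels, k=1))  # tensor([2, 0]): k is now a plain int, not a tensor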
@@ -4,7 +4,7 @@ import numpy as np

import torch

from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
                                         equal_int_shape)
                                         equal_int_shape, get_flat)


def squared_euclidean_distance(x, y):
@@ -12,12 +12,10 @@ def squared_euclidean_distance(x, y):

    Compute :math:`{\langle \bm x - \bm y \rangle}_2`

    :param `torch.tensor` x: Two dimensional vector
    :param `torch.tensor` y: Two dimensional vector

    **Alias:**
    ``prototorch.functions.distances.sed``
    """
    x, y = get_flat(x, y)
    expanded_x = x.unsqueeze(dim=1)
    batchwise_difference = y - expanded_x
    differences_raised = torch.pow(batchwise_difference, 2)
@@ -30,18 +28,17 @@ def euclidean_distance(x, y):

    Compute :math:`\sqrt{{\langle \bm x - \bm y \rangle}_2}`

    :param `torch.tensor` x: Input Tensor of shape :math:`X \times N`
    :param `torch.tensor` y: Input Tensor of shape :math:`Y \times N`

    :returns: Distance Tensor of shape :math:`X \times Y`
    :rtype: `torch.tensor`
    """
    x, y = get_flat(x, y)
    distances_raised = squared_euclidean_distance(x, y)
    distances = torch.sqrt(distances_raised)
    return distances


def euclidean_distance_v2(x, y):
    x, y = get_flat(x, y)
    diff = y - x.unsqueeze(1)
    pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
    # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
@@ -62,10 +59,9 @@ def lpnorm_distance(x, y, p):

    Calls ``torch.cdist``

    :param `torch.tensor` x: Two dimensional vector
    :param `torch.tensor` y: Two dimensional vector
    :param p: p parameter of the lp norm
    """
    x, y = get_flat(x, y)
    distances = torch.cdist(x, y, p=p)
    return distances

@@ -75,10 +71,9 @@ def omega_distance(x, y, omega):

    Compute :math:`{\| \Omega \bm x - \Omega \bm y \|}_p`

    :param `torch.tensor` x: Two dimensional vector
    :param `torch.tensor` y: Two dimensional vector
    :param `torch.tensor` omega: Two dimensional matrix
    """
    x, y = get_flat(x, y)
    projected_x = x @ omega
    projected_y = y @ omega
    distances = squared_euclidean_distance(projected_x, projected_y)
@@ -90,10 +85,9 @@ def lomega_distance(x, y, omegas):

    Compute :math:`{\| \Omega_k \bm x - \Omega_k \bm y_k \|}_p`

    :param `torch.tensor` x: Two dimensional vector
    :param `torch.tensor` y: Two dimensional vector
    :param `torch.tensor` omegas: Three dimensional matrix
    """
    x, y = get_flat(x, y)
    projected_x = x @ omegas
    projected_y = torch.diagonal(y @ omegas).T
    expanded_y = torch.unsqueeze(projected_y, dim=1)
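To illustrate the `get_flat` change in the distance functions, a hedged sketch (not part of the diff; shapes chosen for illustration only):

import torch
from prototorch.functions.distances import omega_distance

x = torch.randn(4, 3)  # 4 samples with 3 features
y = torch.randn(2, 3)  # 2 prototypes
omega = torch.eye(3)   # identity omega recovers the squared Euclidean distance
print(omega_distance(x, y, omega).shape)  # torch.Size([4, 2])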
@@ -1,6 +1,11 @@

import torch


def get_flat(*args):
    rv = [x.view(x.size(0), -1) for x in args]
    return rv


def calculate_prototype_accuracy(y_pred, y_true, plabels):
    """Computes the accuracy of a prototype-based model
    via the Winner-Takes-All rule.
@@ -84,6 +89,6 @@ def _check_shapes(signal_int_shape, proto_int_shape):


def _int_and_mixed_shape(tensor):
    shape = mixed_shape(tensor)
    int_shape = tuple([i if isinstance(i, int) else None for i in shape])
    int_shape = tuple(i if isinstance(i, int) else None for i in shape)

    return shape, int_shape
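`get_flat` simply flattens every argument to `(batch, -1)`; a one-line sanity check (not part of the diff):

import torch
from prototorch.functions.helper import get_flat

a, b = torch.randn(8, 2, 3), torch.randn(5, 6)
fa, fb = get_flat(a, b)
print(fa.shape, fb.shape)  # torch.Size([8, 6]) torch.Size([5, 6])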
@@ -15,59 +15,59 @@ def register_initializer(function):


def labels_from(distribution, one_hot=True):
    """Takes a distribution tensor and returns a labels tensor."""
    nclasses = distribution.shape[0]
    llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
    num_classes = distribution.shape[0]
    llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
    # labels = [l for cl in llist for l in cl]  # flatten the list of lists
    flat_llist = list(chain(*llist))  # flatten label list with itertools.chain
    plabels = torch.tensor(flat_llist, requires_grad=False)
    if one_hot:
        return torch.eye(nclasses)[plabels]
        return torch.eye(num_classes)[plabels]
    return plabels


@register_initializer
def ones(x_train, y_train, prototype_distribution, one_hot=True):
    nprotos = torch.sum(prototype_distribution)
    protos = torch.ones(nprotos, *x_train.shape[1:])
    num_protos = torch.sum(prototype_distribution)
    protos = torch.ones(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
    nprotos = torch.sum(prototype_distribution)
    protos = torch.zeros(nprotos, *x_train.shape[1:])
    num_protos = torch.sum(prototype_distribution)
    protos = torch.zeros(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def rand(x_train, y_train, prototype_distribution, one_hot=True):
    nprotos = torch.sum(prototype_distribution)
    protos = torch.rand(nprotos, *x_train.shape[1:])
    num_protos = torch.sum(prototype_distribution)
    protos = torch.rand(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def randn(x_train, y_train, prototype_distribution, one_hot=True):
    nprotos = torch.sum(prototype_distribution)
    protos = torch.randn(nprotos, *x_train.shape[1:])
    num_protos = torch.sum(prototype_distribution)
    protos = torch.randn(num_protos, *x_train.shape[1:])
    plabels = labels_from(prototype_distribution, one_hot)
    return protos, plabels


@register_initializer
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
    nprotos = torch.sum(prototype_distribution)
    num_protos = torch.sum(prototype_distribution)
    pdim = x_train.shape[1]
    protos = torch.empty(nprotos, pdim)
    protos = torch.empty(num_protos, pdim)
    plabels = labels_from(prototype_distribution, one_hot)
    for i, label in enumerate(plabels):
        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
        if one_hot:
            nclasses = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
            num_classes = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        xl = x_train[matcher]
        mean_xl = torch.mean(xl, dim=0)
        protos[i] = mean_xl
@@ -81,15 +81,15 @@ def stratified_random(x_train,
                      prototype_distribution,
                      one_hot=True,
                      epsilon=1e-7):
    nprotos = torch.sum(prototype_distribution)
    num_protos = torch.sum(prototype_distribution)
    pdim = x_train.shape[1]
    protos = torch.empty(nprotos, pdim)
    protos = torch.empty(num_protos, pdim)
    plabels = labels_from(prototype_distribution, one_hot)
    for i, label in enumerate(plabels):
        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
        if one_hot:
            nclasses = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
            num_classes = y_train.size()[1]
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        xl = x_train[matcher]
        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
        random_xl = xl[rand_index]
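`labels_from` in action (not part of the diff; a minimal sketch):

import torch
from prototorch.functions.initializers import labels_from

dist = torch.tensor([2, 3])  # two prototypes for class 0, three for class 1
print(labels_from(dist, one_hot=False))       # tensor([0, 0, 1, 1, 1])
print(labels_from(dist, one_hot=True).shape)  # torch.Size([5, 2])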
@@ -8,12 +8,12 @@ def _get_matcher(targets, labels):
    matcher = torch.eq(targets.unsqueeze(dim=1), labels)
    if labels.ndim == 2:
        # if the labels are one-hot vectors
        nclasses = targets.size()[1]
        matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
        num_classes = targets.size()[1]
        matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
    return matcher


def _get_dp_dm(distances, targets, plabels):
def _get_dp_dm(distances, targets, plabels, with_indices=False):
    """Returns the d+ and d- values for a batch of distances."""
    matcher = _get_matcher(targets, plabels)
    not_matcher = torch.bitwise_not(matcher)
@@ -21,9 +21,11 @@ def _get_dp_dm(distances, targets, plabels):
    inf = torch.full_like(distances, fill_value=float("inf"))
    d_matching = torch.where(matcher, distances, inf)
    d_unmatching = torch.where(not_matcher, distances, inf)
    dp = torch.min(d_matching, dim=1, keepdim=True).values
    dm = torch.min(d_unmatching, dim=1, keepdim=True).values
    return dp, dm
    dp = torch.min(d_matching, dim=-1, keepdim=True)
    dm = torch.min(d_unmatching, dim=-1, keepdim=True)
    if with_indices:
        return dp, dm
    return dp.values, dm.values


def glvq_loss(distances, target_labels, prototype_labels):
@@ -47,10 +49,46 @@ def lvq1_loss(distances, target_labels, prototype_labels):


def lvq21_loss(distances, target_labels, prototype_labels):
    """LVQ2.1 loss function with support for one-hot labels.

    See Section 4 [Sato & Yamada]
    https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
    """
    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
    mu = dp - dm
    return mu

    return mu


# Probabilistic
def _get_class_probabilities(probabilities, targets, prototype_labels):
    # Create Label Mapping
    uniques = prototype_labels.unique(sorted=True).tolist()
    key_val = {key: val for key, val in zip(uniques, range(len(uniques)))}

    target_indices = torch.LongTensor(list(map(key_val.get, targets.tolist())))

    whole = probabilities.sum(dim=1)
    correct = probabilities[torch.arange(len(probabilities)), target_indices]
    wrong = whole - correct

    return whole, correct, wrong


def nllr_loss(probabilities, targets, prototype_labels):
    """Compute the Negative Log-Likelihood Ratio loss."""
    _, correct, wrong = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / wrong
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood


def rslvq_loss(probabilities, targets, prototype_labels):
    """Compute the Robust Soft Learning Vector Quantization (RSLVQ) loss."""
    whole, correct, _ = _get_class_probabilities(probabilities, targets,
                                                 prototype_labels)

    likelihood = correct / whole
    log_likelihood = torch.log(likelihood)
    return -1.0 * log_likelihood
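A hedged sketch of the new probabilistic loss (not part of the diff; values are made up for illustration):

import torch
from prototorch.functions.losses import rslvq_loss

probs = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])  # per-prototype posteriors
targets = torch.tensor([0, 2])
plabels = torch.tensor([0, 1, 2])
loss = rslvq_loss(probs, targets, plabels)  # -log(correct / whole), per sample
print(loss)  # roughly tensor([0.3567, 0.5108])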
80  prototorch/functions/pooling.py  Normal file
@@ -0,0 +1,80 @@
"""ProtoTorch pooling functions."""

from typing import Callable

import torch


def stratify_with(values: torch.Tensor,
                  labels: torch.LongTensor,
                  fn: Callable,
                  fill_value: float = 0.0) -> (torch.Tensor):
    """Apply an arbitrary stratification strategy on the columns of `values`.

    The outputs correspond to sorted labels.
    """
    clabels = torch.unique(labels, dim=0, sorted=True)
    num_classes = clabels.size()[0]
    if values.size()[1] == num_classes:
        # skip if stratification is trivial
        return values
    batch_size = values.size()[0]
    winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
    filler = torch.full_like(values.T, fill_value=fill_value)
    for i, cl in enumerate(clabels):
        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
        if labels.ndim == 2:
            # if the labels are one-hot vectors
            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
        cdists = torch.where(matcher, values.T, filler).T
        winning_values[i] = fn(cdists)
    if labels.ndim == 2:
        # Transpose to return with `batch_size` first and
        # reverse the columns to fix the ordering of the classes
        return torch.flip(winning_values.T, dims=(1, ))

    return winning_values.T  # return with `batch_size` first


def stratified_sum_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise sum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.sum(x, dim=1, keepdim=True).squeeze(),
        fill_value=0.0)
    return winning_values


def stratified_min_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise minimum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.min(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=float("inf"))
    return winning_values


def stratified_max_pooling(values: torch.Tensor,
                           labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise maximum."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.max(x, dim=1, keepdim=True).values.squeeze(),
        fill_value=-1.0 * float("inf"))
    return winning_values


def stratified_prod_pooling(values: torch.Tensor,
                            labels: torch.LongTensor) -> (torch.Tensor):
    """Group-wise product."""
    winning_values = stratify_with(
        values,
        labels,
        fn=lambda x: torch.prod(x, dim=1, keepdim=True).squeeze(),
        fill_value=1.0)
    return winning_values
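Usage sketch for the new pooling functions (not part of the diff; values taken from `test_stratified_min` below):

import torch
from prototorch.functions.pooling import stratified_min_pooling

d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0.0, 1.0]])
labels = torch.tensor([0, 0, 1, 2])  # columns 0 and 1 share class 0
print(stratified_min_pooling(d, labels))
# tensor([[0., 2., 3.],
#         [8., 0., 1.]])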
32  prototorch/functions/transforms.py  Normal file
@@ -0,0 +1,32 @@
import torch


# Functions
def gaussian(distances, variance):
    return torch.exp(-(distances * distances) / (2 * variance))


def rank_scaled_gaussian(distances, lambd):
    order = torch.argsort(distances, dim=1)
    ranks = torch.argsort(order, dim=1)

    return torch.exp(-torch.exp(-ranks / lambd) * distances)


# Modules
class GaussianPrior(torch.nn.Module):
    def __init__(self, variance):
        super().__init__()
        self.variance = variance

    def forward(self, distances):
        return gaussian(distances, self.variance)


class RankScaledGaussianPrior(torch.nn.Module):
    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, distances):
        return rank_scaled_gaussian(distances, self.lambd)
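A minimal sketch of the Gaussian transform (not part of the diff):

import torch
from prototorch.functions.transforms import GaussianPrior

prior = GaussianPrior(variance=1.0)
d = torch.tensor([[0.0, 1.0, 2.0]])
print(prior(d))  # exp(-d^2 / 2): tensor([[1.0000, 0.6065, 0.1353]])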
@@ -1,7 +1,5 @@

"""ProtoTorch modules."""

from .prototypes import Prototypes1D

__all__ = [
    "Prototypes1D",
]
from .competitions import *
from .pooling import *
from .wrappers import LambdaLayer, LossLayer
42  prototorch/modules/competitions.py  Normal file
@@ -0,0 +1,42 @@
"""ProtoTorch Competition Modules."""

import torch

from prototorch.functions.competitions import knnc, wtac


class WTAC(torch.nn.Module):
    """Winner-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.

    """
    def forward(self, distances, labels):
        return wtac(distances, labels)


class LTAC(torch.nn.Module):
    """Loser-Takes-All-Competition Layer.

    Thin wrapper over the `wtac` function.

    """
    def forward(self, probs, labels):
        return wtac(-1.0 * probs, labels)


class KNNC(torch.nn.Module):
    """K-Nearest-Neighbors-Competition.

    Thin wrapper over the `knnc` function.

    """
    def __init__(self, k=1, **kwargs):
        super().__init__(**kwargs)
        self.k = k

    def forward(self, distances, labels):
        return knnc(distances, labels, k=self.k)

    def extra_repr(self):
        return f"k: {self.k}"
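The module wrappers behave exactly like the underlying functions; a minimal sketch (not part of the diff):

import torch
from prototorch.modules.competitions import WTAC

wtac_layer = WTAC()
d = torch.tensor([[2.0, 3.0, 1.99, 3.01]])
plabels = torch.tensor([0, 1, 2, 3])
print(wtac_layer(d, plabels))  # tensor(2): label of the closest prototype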
@@ -21,8 +21,8 @@ class GLVQLoss(torch.nn.Module):


class NeuralGasEnergy(torch.nn.Module):
    def __init__(self, lm):
        super().__init__()
    def __init__(self, lm, **kwargs):
        super().__init__(**kwargs)
        self.lm = lm

    def forward(self, d):
@@ -38,3 +38,22 @@ class NeuralGasEnergy(torch.nn.Module):
    @staticmethod
    def _nghood_fn(rankings, lm):
        return torch.exp(-rankings / lm)


class GrowingNeuralGasEnergy(NeuralGasEnergy):
    def __init__(self, topology_layer, **kwargs):
        super().__init__(**kwargs)
        self.topology_layer = topology_layer

    @staticmethod
    def _nghood_fn(rankings, topology):
        winner = rankings[:, 0]

        weights = torch.zeros_like(rankings, dtype=torch.float)
        weights[torch.arange(rankings.shape[0]), winner] = 1.0

        neighbours = topology.get_neighbours(winner)

        weights[neighbours] = 0.1

        return weights
@@ -1,11 +1,9 @@

import torch
from torch import nn

from prototorch.functions.distances import (euclidean_distance_matrix,
                                            tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D


class GTLVQ(nn.Module):
@@ -80,45 +78,35 @@ class GTLVQ(nn.Module):
        super(GTLVQ, self).__init__()

        self.num_protos = num_classes * prototypes_per_class
        self.num_protos_class = prototypes_per_class
        self.subspace_size = feature_dim if subspace_size is None else subspace_size
        self.feature_dim = feature_dim
        self.num_classes = num_classes

        cls_initializer = StratifiedMeanInitializer(prototype_data)
        cls_distribution = {
            "num_classes": num_classes,
            "prototypes_per_class": prototypes_per_class,
        }

        self.cls = LabeledComponents(cls_distribution, cls_initializer)

        if subspace_data is None:
            raise ValueError("Init Data must be specified!")

        self.tpt = tangent_projection_type
        with torch.no_grad():
            if self.tpt == "local" or self.tpt == "local_proj":
                self.init_local_subspace(subspace_data)
            if self.tpt == "local":
                self.init_local_subspace(subspace_data, subspace_size,
                                         self.num_protos)
            elif self.tpt == "global":
                self.init_gobal_subspace(subspace_data, subspace_size)
            else:
                self.subspaces = None

        # Hypothesis-Margin-Classifier
        self.cls = Prototypes1D(
            input_dim=feature_dim,
            prototypes_per_class=prototypes_per_class,
            nclasses=num_classes,
            prototype_initializer="stratified_mean",
            data=prototype_data,
        )

    def forward(self, x):
        # Tangent Projection
        if self.tpt == "local_proj":
            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
                                                          1).unsqueeze(2))
            dis, proj_x = self.local_tangent_projection(x_conform)

            proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
                                    self.feature_dim)
            return proj_x, dis
        elif self.tpt == "local":
            x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
                                                          1).unsqueeze(2))
            dis = tangent_distance(x_conform, self.cls.prototypes,
                                   self.subspaces)
        if self.tpt == "local":
            dis = self.local_tangent_distances(x)
        elif self.tpt == "global":
            dis = self.global_tangent_distances(x)
        else:
@@ -131,16 +119,14 @@ class GTLVQ(nn.Module):
        _, _, v = torch.svd(data)
        subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = subspace[:, :num_subspaces]
        self.subspaces = (torch.nn.Parameter(
            subspaces).clone().detach().requires_grad_(True))
        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

    def init_local_subspace(self, data):
        _, _, v = torch.svd(data)
        inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
        subspaces = inital_projector.unsqueeze(0).repeat_interleave(
            self.num_protos, 0)
        self.subspaces = (torch.nn.Parameter(
            subspaces).clone().detach().requires_grad_(True))
    def init_local_subspace(self, data, num_subspaces, num_protos):
        data = data - torch.mean(data, dim=0)
        _, _, v = torch.svd(data, some=False)
        v = v[:, :num_subspaces]
        subspaces = v.unsqueeze(0).repeat_interleave(num_protos, 0)
        self.subspaces = nn.Parameter(subspaces, requires_grad=True)

    def global_tangent_distances(self, x):
        # Tangent Projection
@@ -151,37 +137,26 @@ class GTLVQ(nn.Module):
        # Euclidean Distance
        return euclidean_distance_matrix(x, projected_prototypes)

    def local_tangent_projection(self, signals):
        # Note: subspaces is always assumed as transposed and must be orthogonal!
        # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
        # shape(protos): proto_number x dim1 x dim2 x ... x dimN
        # shape(subspaces): (optional [proto_number]) x prod(dim1 * dim2 * ... * dimN) x prod(projected_atom_shape)
        # subspace should be orthogonalized
        # Origin Source Code
        # Origin Author:
        protos = self.cls.prototypes
        subspaces = self.subspaces
        signal_shape, signal_int_shape = _int_and_mixed_shape(signals)
        _, proto_int_shape = _int_and_mixed_shape(protos)
    def local_tangent_distances(self, x):

        # check if the shapes are correct
        _check_shapes(signal_int_shape, proto_int_shape)

        # Tangent Data Projections
        projected_protos = torch.bmm(protos.unsqueeze(1), subspaces).squeeze(1)
        data = signals.squeeze(2).permute([1, 0, 2])
        projected_data = torch.bmm(data, subspaces)
        projected_data = projected_data.permute([1, 0, 2]).unsqueeze(1)
        diff = projected_data - projected_protos
        projected_diff = torch.reshape(
            diff, (signal_shape[1], signal_shape[0], signal_shape[2]) +
            signal_shape[3:])
        diss = torch.norm(projected_diff, 2, dim=-1)
        return diss.permute([1, 0, 2]).squeeze(-1), projected_data.squeeze(1)
        # Tangent Distance
        x = x.unsqueeze(1).expand(x.size(0), self.cls.num_components,
                                  x.size(-1))
        protos = self.cls()[0].unsqueeze(0).expand(x.size(0),
                                                   self.cls.num_components,
                                                   x.size(-1))
        projectors = torch.eye(
            self.subspaces.shape[-2], device=x.device) - torch.bmm(
                self.subspaces, self.subspaces.permute([0, 2, 1]))
        diff = (x - protos)
        diff = diff.permute([1, 0, 2])
        diff = torch.bmm(diff, projectors)
        diff = torch.norm(diff, 2, dim=-1).T
        return diff

    def get_parameters(self):
        return {
            "params": self.cls.prototypes,
            "params": self.cls.components,
        }, {
            "params": self.subspaces
        }
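For intuition: the rewritten local tangent distance measures, per prototype, the norm of x - w after projection onto the orthogonal complement I - WW^T of that prototype's subspace basis W. A standalone sketch with made-up shapes (hypothetical names, not the class API):

import torch

batch, num_protos, dim, subdim = 4, 3, 5, 2
x = torch.randn(batch, dim)
protos = torch.randn(num_protos, dim)
# one orthonormal basis W per prototype, here via a batched QR decomposition
W, _ = torch.linalg.qr(torch.randn(num_protos, dim, subdim))

projectors = torch.eye(dim) - torch.bmm(W, W.permute(0, 2, 1))  # I - W W^T
diff = (x.unsqueeze(1) - protos.unsqueeze(0)).permute(1, 0, 2)  # (protos, batch, dim)
dis = torch.norm(torch.bmm(diff, projectors), 2, dim=-1).T      # (batch, protos)
print(dis.shape)  # torch.Size([4, 3])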
32  prototorch/modules/pooling.py  Normal file
@@ -0,0 +1,32 @@
"""ProtoTorch Pooling Modules."""

import torch

from prototorch.functions.pooling import (stratified_max_pooling,
                                          stratified_min_pooling,
                                          stratified_prod_pooling,
                                          stratified_sum_pooling)


class StratifiedSumPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_sum_pooling` function."""
    def forward(self, values, labels):
        return stratified_sum_pooling(values, labels)


class StratifiedProdPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_prod_pooling` function."""
    def forward(self, values, labels):
        return stratified_prod_pooling(values, labels)


class StratifiedMinPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_min_pooling` function."""
    def forward(self, values, labels):
        return stratified_min_pooling(values, labels)


class StratifiedMaxPooling(torch.nn.Module):
    """Thin wrapper over the `stratified_max_pooling` function."""
    def forward(self, values, labels):
        return stratified_max_pooling(values, labels)
@@ -1,137 +0,0 @@

"""ProtoTorch prototype modules."""

import warnings

import torch

from prototorch.functions.initializers import get_initializer


class _Prototypes(torch.nn.Module):
    """Abstract prototypes class."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _validate_prototype_distribution(self):
        if 0 in self.prototype_distribution:
            warnings.warn("Are you sure about the `0` in "
                          "`prototype_distribution`?")

    def extra_repr(self):
        return f"prototypes.shape: {tuple(self.prototypes.shape)}"

    def forward(self):
        return self.prototypes, self.prototype_labels


class Prototypes1D(_Prototypes):
    """Create a learnable set of one-dimensional prototypes.

    TODO Complete this doc-string.
    """
    def __init__(
        self,
        prototypes_per_class=1,
        prototype_initializer="ones",
        prototype_distribution=None,
        data=None,
        dtype=torch.float32,
        one_hot_labels=False,
        **kwargs,
    ):
        warnings.warn(
            PendingDeprecationWarning(
                "Prototypes1D will be replaced in future versions."))

        # Convert tensors to python lists before processing
        if prototype_distribution is not None:
            if not isinstance(prototype_distribution, list):
                prototype_distribution = prototype_distribution.tolist()

        if data is None:
            if "input_dim" not in kwargs:
                raise NameError("`input_dim` required if "
                                "no `data` is provided.")
            if prototype_distribution:
                kwargs_nclasses = sum(prototype_distribution)
            else:
                if "nclasses" not in kwargs:
                    raise NameError("`prototype_distribution` required if "
                                    "both `data` and `nclasses` are not "
                                    "provided.")
                kwargs_nclasses = kwargs.pop("nclasses")
            input_dim = kwargs.pop("input_dim")
            if prototype_initializer in [
                    "stratified_mean", "stratified_random"
            ]:
                warnings.warn(
                    f"`prototype_initializer`: `{prototype_initializer}` "
                    "requires `data`, but `data` is not provided. "
                    "Using randomly generated data instead.")
            x_train = torch.rand(kwargs_nclasses, input_dim)
            y_train = torch.arange(kwargs_nclasses)
            if one_hot_labels:
                y_train = torch.eye(kwargs_nclasses)[y_train]
            data = [x_train, y_train]

        x_train, y_train = data
        x_train = torch.as_tensor(x_train).type(dtype)
        y_train = torch.as_tensor(y_train).type(torch.int)
        nclasses = torch.unique(y_train, dim=-1).shape[-1]

        if nclasses == 1:
            warnings.warn("Are you sure about having one class only?")

        if x_train.ndim != 2:
            raise ValueError("`data[0].ndim != 2`.")

        if y_train.ndim == 2:
            if y_train.shape[1] == 1 and one_hot_labels:
                raise ValueError("`one_hot_labels` is set to `True` "
                                 "but target labels are not one-hot-encoded.")
            if y_train.shape[1] != 1 and not one_hot_labels:
                raise ValueError("`one_hot_labels` is set to `False` "
                                 "but target labels in `data` "
                                 "are one-hot-encoded.")
        if y_train.ndim == 1 and one_hot_labels:
            raise ValueError("`one_hot_labels` is set to `True` "
                             "but target labels are not one-hot-encoded.")

        # Verify input dimension if `input_dim` is provided
        if "input_dim" in kwargs:
            input_dim = kwargs.pop("input_dim")
            if input_dim != x_train.shape[1]:
                raise ValueError(f"Provided `input_dim`={input_dim} does "
                                 "not match data dimension "
                                 f"`data[0].shape[1]`={x_train.shape[1]}")

        # Verify the number of classes if `nclasses` is provided
        if "nclasses" in kwargs:
            kwargs_nclasses = kwargs.pop("nclasses")
            if kwargs_nclasses != nclasses:
                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
                                 "not match data labels "
                                 "`torch.unique(data[1]).shape[0]`"
                                 f"={nclasses}")

        super().__init__(**kwargs)

        if not prototype_distribution:
            prototype_distribution = [prototypes_per_class] * nclasses
        with torch.no_grad():
            self.prototype_distribution = torch.tensor(prototype_distribution)

        self._validate_prototype_distribution()

        self.prototype_initializer = get_initializer(prototype_initializer)
        prototypes, prototype_labels = self.prototype_initializer(
            x_train,
            y_train,
            prototype_distribution=self.prototype_distribution,
            one_hot=one_hot_labels,
        )

        # Register module parameters
        self.prototypes = torch.nn.Parameter(prototypes)
        self.prototype_labels = torch.nn.Parameter(
            prototype_labels.type(dtype)).requires_grad_(False)
36  prototorch/modules/wrappers.py  Normal file
@@ -0,0 +1,36 @@
"""ProtoTorch Wrappers."""

import torch


class LambdaLayer(torch.nn.Module):
    def __init__(self, fn, name=None):
        super().__init__()
        self.fn = fn
        self.name = name or fn.__name__  # lambda fns get <lambda>

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def extra_repr(self):
        return self.name


class LossLayer(torch.nn.modules.loss._Loss):
    def __init__(self,
                 fn,
                 name=None,
                 size_average=None,
                 reduce=None,
                 reduction: str = "mean") -> None:
        super().__init__(size_average=size_average,
                         reduce=reduce,
                         reduction=reduction)
        self.fn = fn
        self.name = name or fn.__name__  # lambda fns get <lambda>

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def extra_repr(self):
        return self.name
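`LambdaLayer` wraps any callable as an `nn.Module`; a minimal sketch (not part of the diff):

import torch
from prototorch.modules.wrappers import LambdaLayer

double = LambdaLayer(lambda x: 2 * x, name="double")
print(double(torch.tensor([1.0, 2.0])))  # tensor([2., 4.])
print(double)  # repr shows the given name via extra_repr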
@@ -1,243 +0,0 @@

"""Utilities that provide various small functionalities."""

import os
import pickle
import sys
from time import time

import matplotlib.pyplot as plt
import numpy as np


def progressbar(title, value, end, bar_width=20):
    percent = float(value) / end
    arrow = "=" * int(round(percent * bar_width) - 1) + ">"
    spaces = "." * (bar_width - len(arrow))
    sys.stdout.write("\r{}: [{}] {}%".format(title, arrow + spaces,
                                             int(round(percent * 100))))
    sys.stdout.flush()
    if percent == 1.0:
        print()


def prettify_string(inputs, start="", sep=" ", end="\n"):
    outputs = start + " ".join(inputs.split()) + end
    return outputs


def pretty_print(inputs):
    print(prettify_string(inputs))


def writelog(self, *logs, logdir="./logs", logfile="run.txt"):
    f = os.path.join(logdir, logfile)
    with open(f, "a+") as fh:
        for log in logs:
            fh.write(log)
            fh.write("\n")


def start_tensorboard(self, logdir="./logs"):
    cmd = f"tensorboard --logdir={logdir} --port=6006"
    os.system(cmd)


def make_directory(save_dir):
    if not os.path.exists(save_dir):
        print(f"Making directory {save_dir}.")
        os.mkdir(save_dir)


def make_gif(filenames, duration, output_file=None):
    try:
        import imageio
    except ModuleNotFoundError as e:
        print("Please install Protoflow with [other] extra requirements.")
        raise (e)

    images = list()
    for filename in filenames:
        images.append(imageio.imread(filename))
    if not output_file:
        output_file = f"makegif.gif"
    if images:
        imageio.mimwrite(output_file, images, duration=duration)


def gif_from_dir(directory,
                 duration,
                 prefix="",
                 output_file=None,
                 verbose=True):
    images = os.listdir(directory)
    if verbose:
        print(f"Making gif from {len(images)} images under {directory}.")
    filenames = list()
    # Sort images
    images = sorted(
        images,
        key=lambda img: int(os.path.splitext(img)[0].replace(prefix, "")))
    for image in images:
        fname = os.path.join(directory, image)
        filenames.append(fname)
    if not output_file:
        output_file = os.path.join(directory, "makegif.gif")
    make_gif(filenames=filenames, duration=duration, output_file=output_file)


def accuracy_score(y_true, y_pred):
    accuracy = np.sum(y_true == y_pred)
    normalized_acc = accuracy / float(len(y_true))
    return normalized_acc


def predict_and_score(clf,
                      x_test,
                      y_test,
                      verbose=False,
                      title="Test accuracy"):
    y_pred = clf.predict(x_test)
    accuracy = np.sum(y_test == y_pred)
    normalized_acc = accuracy / float(len(y_test))
    if verbose:
        print(f"{title}: {normalized_acc * 100:06.04f}%")
    return normalized_acc


def remove_nan_rows(arr):
    """Remove all rows with `nan` values in `arr`."""
    mask = np.isnan(arr).any(axis=1)
    return arr[~mask]


def remove_nan_cols(arr):
    """Remove all columns with `nan` values in `arr`."""
    mask = np.isnan(arr).any(axis=0)
    return arr[~mask]


def replace_in(arr, replacement_dict, inplace=False):
    """Replace the keys found in `arr` with the values from
    the `replacement_dict`.
    """
    if inplace:
        new_arr = arr
    else:
        import copy

        new_arr = copy.deepcopy(arr)
    for k, v in replacement_dict.items():
        new_arr[arr == k] = v
    return new_arr


def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
    """Split a classification dataset in such a way so as to
    preserve the class distribution in subsamples of the dataset.
    """
    if train + val > 1.0:
        raise ValueError("Invalid split values for train and val.")
    Y = data[:, -1]
    labels = set(Y)
    hist = dict()
    for l in labels:
        data_l = data[Y == l]
        nl = len(data_l)
        nl_train = int(nl * train)
        nl_val = int(nl * val)
        nl_test = nl - (nl_train + nl_val)
        hist[l] = (nl_train, nl_val, nl_test)

    train_data = list()
    val_data = list()
    test_data = list()
    for l, (nl_train, nl_val, nl_test) in hist.items():
        data_l = data[Y == l]
        if shuffle:
            np.random.shuffle(data_l)
        train_l = data_l[:nl_train]
        val_l = data_l[nl_train:nl_train + nl_val]
        test_l = data_l[nl_train + nl_val:nl_train + nl_val + nl_test]
        train_data.append(train_l)
        val_data.append(val_l)
        test_data.append(test_l)

    def _squash(data_list):
        data = np.array(data_list[0])
        for item in data_list[1:]:
            data = np.vstack((data, np.array(item)))
        return data

    train_data = _squash(train_data)
    if val_data:
        val_data = _squash(val_data)
    if test_data:
        test_data = _squash(test_data)
    if return_xy:
        x_train = train_data[:, :-1]
        y_train = train_data[:, -1]
        x_val = val_data[:, :-1]
        y_val = val_data[:, -1]
        x_test = test_data[:, :-1]
        y_test = test_data[:, -1]
        return (x_train, y_train), (x_val, y_val), (x_test, y_test)
    return train_data, val_data, test_data


def class_histogram(data, title="Untitled"):
    plt.figure(title)
    plt.clf()
    plt.title(title)
    dist, counts = np.unique(data[:, -1], return_counts=True)
    plt.bar(dist, counts)
    plt.xticks(dist)
    print("Call matplotlib.pyplot.show() to see the plot.")


def ntimer(n=10):
    """Wraps a function which wraps another function to time it."""
    if n < 1:
        raise (Exception(f"Invalid n = {n} given."))

    def timer(func):
        """Wraps `func` with a timer and returns the wrapped `func`."""
        def wrapper(*args, **kwargs):
            rv = None
            before = time()
            for _ in range(n):
                rv = func(*args, **kwargs)
            after = time()
            elapsed = after - before
            print(f"Elapsed: {elapsed*1e3:02.02f} ms")
            return rv

        return wrapper

    return timer


def memoize(verbose=True):
    """Wraps a function which wraps another function that memoizes."""
    def memoizer(func):
        """Memoize (cache) return values of `func`.
        Wraps `func` and returns the wrapped `func` so that `func`
        is executed when the results are not available in the cache.
        """
        cache = {}

        def wrapper(*args, **kwargs):
            t = (pickle.dumps(args), pickle.dumps(kwargs))
            if t not in cache:
                if verbose:
                    print(f"Adding NEW rv {func.__name__}{args}{kwargs} "
                          "to cache.")
                cache[t] = func(*args, **kwargs)
            else:
                if verbose:
                    print(f"Using OLD rv {func.__name__}{args}{kwargs} "
                          "from cache.")
            return cache[t]

        return wrapper

    return memoizer
26  setup.py
@@ -1,10 +1,12 @@
"""
_____ _ _______ _
| __ \ | | |__ __| | |
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
| | | | | (_) | || (_) | | (_) | | | (__| | | |
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|

######
# # ##### #### ##### #### ##### #### ##### #### # #
# # # # # # # # # # # # # # # # # #
###### # # # # # # # # # # # # # ######
# ##### # # # # # # # # ##### # # #
# # # # # # # # # # # # # # # # #
# # # #### # #### # #### # # #### # #

ProtoTorch Core Package
"""
@@ -18,22 +20,26 @@ with open("README.md", "r") as fh:

INSTALL_REQUIRES = [
    "torch>=1.3.1",
    "torchvision>=0.5.0",
    "torchvision>=0.5.1",
    "numpy>=1.9.1",
    "sklearn",
]
DATASETS = [
    "requests",
    "tqdm",
]
DEV = ["bumpversion"]
DEV = [
    "bumpversion",
    "pre-commit",
]
DOCS = [
    "recommonmark",
    "sphinx",
    "sphinx_rtd_theme",
    "sphinxcontrib-katex",
    "sphinx-autodoc-typehints",
]
EXAMPLES = [
    "sklearn",
    "matplotlib",
    "torchinfo",
]
@@ -42,7 +48,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS

setup(
    name="prototorch",
    version="0.4.2",
    version="0.5.1",
    description="Highly extensible, GPU-supported "
    "Learning Vector Quantization (LVQ) toolbox "
    "built using PyTorch and its nn API.",
26  tests/test_components.py  Normal file
@@ -0,0 +1,26 @@
"""ProtoTorch components test suite."""

import torch

import prototorch as pt


def test_labcomps_zeros_init():
    protos = torch.zeros(3, 2)
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=pt.components.Zeros(2),
    )
    assert (c.components == protos).any() == True


def test_labcomps_warmstart():
    protos = torch.randn(3, 2)
    plabels = torch.tensor([1, 2, 3])
    c = pt.components.LabeledComponents(
        distribution=[1, 1, 1],
        initializer=None,
        initialized_components=[protos, plabels],
    )
    assert (c.components == protos).any() == True
    assert (c.component_labels == plabels).any() == True
@@ -6,7 +6,7 @@ import numpy as np

import torch

from prototorch.functions import (activations, competitions, distances,
                                  initializers, losses)
                                  initializers, losses, pooling)


class TestActivations(unittest.TestCase):
@@ -105,10 +105,28 @@ class TestCompetitions(unittest.TestCase):
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_knnc_k1(self):
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
        labels = torch.tensor([0, 1, 2, 3])
        actual = competitions.knnc(d, labels, k=1)
        desired = torch.tensor([2, 0])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def tearDown(self):
        pass


class TestPooling(unittest.TestCase):
    def setUp(self):
        pass

    def test_stratified_min(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        actual = competitions.stratified_min(d, labels)
        actual = pooling.stratified_min_pooling(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
@@ -119,28 +137,70 @@ class TestCompetitions(unittest.TestCase):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        actual = competitions.stratified_min(d, labels)
        actual = pooling.stratified_min_pooling(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_min_simple(self):
    def test_stratified_min_trivial(self):
        d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
        labels = torch.tensor([0, 1, 2])
        actual = competitions.stratified_min(d, labels)
        actual = pooling.stratified_min_pooling(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_knnc_k1(self):
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
        labels = torch.tensor([0, 1, 2, 3])
        actual = competitions.knnc(d, labels, k=torch.tensor([1]))
        desired = torch.tensor([2, 0])
    def test_stratified_max(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        actual = pooling.stratified_max_pooling(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_max_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 2, 1, 0])
        labels = torch.nn.functional.one_hot(labels, num_classes=3)
        actual = pooling.stratified_max_pooling(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_sum(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.LongTensor([0, 0, 1, 2])
        actual = pooling.stratified_sum_pooling(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_sum_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        actual = pooling.stratified_sum_pooling(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_stratified_prod(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        actual = pooling.stratified_prod_pooling(d, labels)
        desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
@@ -1,298 +0,0 @@

"""ProtoTorch modules test suite."""

import unittest

import numpy as np
import torch

from prototorch.modules import losses, prototypes


class TestPrototypes(unittest.TestCase):
    def setUp(self):
        self.x = torch.tensor(
            [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
            dtype=torch.float32)
        self.y = torch.tensor([0, 0, 1, 1])
        self.gen = torch.manual_seed(42)

    def test_prototypes1d_init_without_input_dim(self):
        with self.assertRaises(NameError):
            _ = prototypes.Prototypes1D(nclasses=2)

    def test_prototypes1d_init_without_nclasses(self):
        with self.assertRaises(NameError):
            _ = prototypes.Prototypes1D(input_dim=1)

    def test_prototypes1d_init_with_nclasses_1(self):
        with self.assertWarns(UserWarning):
            _ = prototypes.Prototypes1D(nclasses=1, input_dim=1)

    def test_prototypes1d_init_without_pdist(self):
        p1 = prototypes.Prototypes1D(
            input_dim=6,
            nclasses=2,
            prototypes_per_class=4,
            prototype_initializer="ones",
        )
        protos = p1.prototypes
        actual = protos.detach().numpy()
        desired = torch.ones(8, 6)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_prototypes1d_init_without_data(self):
        pdist = [2, 2]
        p1 = prototypes.Prototypes1D(input_dim=3,
                                     prototype_distribution=pdist,
                                     prototype_initializer="zeros")
        protos = p1.prototypes
        actual = protos.detach().numpy()
        desired = torch.zeros(4, 3)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_prototypes1d_proto_init_without_data(self):
        with self.assertWarns(UserWarning):
            _ = prototypes.Prototypes1D(
                input_dim=3,
                nclasses=2,
                prototypes_per_class=1,
                prototype_initializer="stratified_mean",
                data=None,
            )

    def test_prototypes1d_init_torch_pdist(self):
        pdist = torch.tensor([2, 2])
        p1 = prototypes.Prototypes1D(input_dim=3,
                                     prototype_distribution=pdist,
                                     prototype_initializer="zeros")
        protos = p1.prototypes
        actual = protos.detach().numpy()
        desired = torch.zeros(4, 3)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_prototypes1d_init_without_inputdim_with_data(self):
        _ = prototypes.Prototypes1D(
            nclasses=2,
            prototypes_per_class=1,
            prototype_initializer="stratified_mean",
            data=[[[1.0], [0.0]], [1, 0]],
        )

    def test_prototypes1d_init_with_int_data(self):
        _ = prototypes.Prototypes1D(
            nclasses=2,
            prototypes_per_class=1,
            prototype_initializer="stratified_mean",
            data=[[[1], [0]], [1, 0]],
        )

    def test_prototypes1d_init_one_hot_without_data(self):
        _ = prototypes.Prototypes1D(
            input_dim=1,
            nclasses=2,
            prototypes_per_class=1,
            prototype_initializer="stratified_mean",
            data=None,
            one_hot_labels=True,
        )

    def test_prototypes1d_init_one_hot_labels_false(self):
        """Test if ValueError is raised when `one_hot_labels` is set to `False`
        but the provided `data` has one-hot encoded labels.
        """
        with self.assertRaises(ValueError):
            _ = prototypes.Prototypes1D(
                input_dim=1,
                nclasses=2,
                prototypes_per_class=1,
                prototype_initializer="stratified_mean",
                data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
                one_hot_labels=False,
            )

    def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
        """Test if ValueError is raised when `one_hot_labels` is set to `True`
        but the provided `data` does not contain one-hot encoded labels.
        """
        with self.assertRaises(ValueError):
            _ = prototypes.Prototypes1D(
                input_dim=1,
                nclasses=2,
                prototypes_per_class=1,
                prototype_initializer="stratified_mean",
                data=([[0.0], [1.0]], [0, 1]),
                one_hot_labels=True,
            )

    def test_prototypes1d_init_one_hot_labels_true(self):
        """Test if ValueError is raised when `one_hot_labels` is set to `True`
        but the provided `data` contains 2D targets but
        does not contain one-hot encoded labels.
        """
        with self.assertRaises(ValueError):
            _ = prototypes.Prototypes1D(
                input_dim=1,
                nclasses=2,
                prototypes_per_class=1,
                prototype_initializer="stratified_mean",
                data=([[0.0], [1.0]], [[0], [1]]),
                one_hot_labels=True,
            )

    def test_prototypes1d_init_with_int_dtype(self):
        with self.assertRaises(RuntimeError):
            _ = prototypes.Prototypes1D(
                nclasses=2,
                prototypes_per_class=1,
                prototype_initializer="stratified_mean",
                data=[[[1], [0]], [1, 0]],
                dtype=torch.int32,
            )

    def test_prototypes1d_inputndim_with_data(self):
        with self.assertRaises(ValueError):
            _ = prototypes.Prototypes1D(input_dim=1,
                                        nclasses=1,
                                        prototypes_per_class=1,
                                        data=[[1.0], [1]])

    def test_prototypes1d_inputdim_with_data(self):
        with self.assertRaises(ValueError):
            _ = prototypes.Prototypes1D(
                input_dim=2,
                nclasses=2,
                prototypes_per_class=1,
                prototype_initializer="stratified_mean",
                data=[[[1.0], [0.0]], [1, 0]],
            )

    def test_prototypes1d_nclasses_with_data(self):
        """Test ValueError raise if provided `nclasses` is not the same
        as the one computed from the provided `data`.
        """
        with self.assertRaises(ValueError):
            _ = prototypes.Prototypes1D(
                input_dim=1,
                nclasses=1,
                prototypes_per_class=1,
                prototype_initializer="stratified_mean",
                data=[[[1.0], [2.0]], [1, 2]],
            )

    def test_prototypes1d_init_with_ppc(self):
        p1 = prototypes.Prototypes1D(data=[self.x, self.y],
                                     prototypes_per_class=2,
                                     prototype_initializer="zeros")
        protos = p1.prototypes
        actual = protos.detach().numpy()
        desired = torch.zeros(4, 3)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_prototypes1d_init_with_pdist(self):
        p1 = prototypes.Prototypes1D(
            data=[self.x, self.y],
            prototype_distribution=[6, 9],
            prototype_initializer="zeros",
        )
        protos = p1.prototypes
        actual = protos.detach().numpy()
        desired = torch.zeros(15, 3)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_prototypes1d_func_initializer(self):
        def my_initializer(*args, **kwargs):
            return torch.full((2, 99), 99.0), torch.tensor([0, 1])

        p1 = prototypes.Prototypes1D(
            input_dim=99,
            nclasses=2,
            prototypes_per_class=1,
            prototype_initializer=my_initializer,
        )
        protos = p1.prototypes
        actual = protos.detach().numpy()
        desired = 99 * torch.ones(2, 99)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_prototypes1d_forward(self):
        p1 = prototypes.Prototypes1D(data=[self.x, self.y])
        protos, _ = p1()
        actual = protos.detach().numpy()
        desired = torch.ones(2, 3)
        mismatch = np.testing.assert_array_almost_equal(actual,
                                                        desired,
                                                        decimal=5)
        self.assertIsNone(mismatch)

    def test_prototypes1d_dist_validate(self):
        p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
        with self.assertWarns(UserWarning):
            _ = p1._validate_prototype_distribution()

    def test_prototypes1d_validate_extra_repr_not_empty(self):
        p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
        rep = p1.extra_repr()
        self.assertNotEqual(rep, "")

    def tearDown(self):
        del self.x, self.y, self.gen
        _ = torch.seed()


class TestLosses(unittest.TestCase):
    def setUp(self):
        pass

    def test_glvqloss_init(self):
        _ = losses.GLVQLoss(0, "swish_beta", beta=20)

    def test_glvqloss_forward_1ppc(self):
        criterion = losses.GLVQLoss(margin=0,
                                    squashing="sigmoid_beta",
                                    beta=100)
        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
        labels = torch.tensor([0, 1])
        targets = torch.ones(100)
        outputs = [d, labels]
        loss = criterion(outputs, targets)
        loss_value = loss.item()
        self.assertAlmostEqual(loss_value, 0.0)

    def test_glvqloss_forward_2ppc(self):
        criterion = losses.GLVQLoss(margin=0,
                                    squashing="sigmoid_beta",
                                    beta=100)
        d = torch.stack([
            torch.ones(100),
            torch.ones(100),
            torch.zeros(100),
            torch.ones(100)
        ],
                        dim=1)
        labels = torch.tensor([0, 0, 1, 1])
        targets = torch.ones(100)
        outputs = [d, labels]
        loss = criterion(outputs, targets)
        loss_value = loss.item()
        self.assertAlmostEqual(loss_value, 0.0)

    def tearDown(self):
        pass