Version 0.4.0
This commit is contained in:
commit
3a8388e24f
@ -1,20 +1,11 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.3.0-dev0
|
current_version = 0.4.0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
serialize =
|
serialize =
|
||||||
{major}.{minor}.{patch}-{release}{build}
|
|
||||||
{major}.{minor}.{patch}
|
{major}.{minor}.{patch}
|
||||||
|
|
||||||
[bumpversion:part:release]
|
|
||||||
optional_value = prod
|
|
||||||
first_value = dev
|
|
||||||
values =
|
|
||||||
dev
|
|
||||||
rc
|
|
||||||
prod
|
|
||||||
|
|
||||||
[bumpversion:file:setup.py]
|
[bumpversion:file:setup.py]
|
||||||
|
|
||||||
[bumpversion:file:./prototorch/__init__.py]
|
[bumpversion:file:./prototorch/__init__.py]
|
||||||
|
@ -31,15 +31,15 @@ To also install the extras, use
|
|||||||
pip install -U prototorch[all]
|
pip install -U prototorch[all]
|
||||||
```
|
```
|
||||||
|
|
||||||
*Note: If you're using [ZSH](https://www.zsh.org/), the square brackets `[ ]`
|
*Note: If you're using [ZSH](https://www.zsh.org/) (which is also the default
|
||||||
have to be escaped like so: `\[\]`, making the install command `pip install -U
|
shell on MacOS now), the square brackets `[ ]` have to be escaped like so:
|
||||||
prototorch\[all\]`.*
|
`\[\]`, making the install command `pip install -U prototorch\[all\]`.*
|
||||||
|
|
||||||
To install the bleeding-edge features and improvements:
|
To install the bleeding-edge features and improvements:
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/si-cim/prototorch.git
|
git clone https://github.com/si-cim/prototorch.git
|
||||||
git checkout dev
|
|
||||||
cd prototorch
|
cd prototorch
|
||||||
|
git checkout dev
|
||||||
pip install -e .[all]
|
pip install -e .[all]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -12,9 +12,8 @@
|
|||||||
#
|
#
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
sys.path.insert(0, os.path.abspath("../../"))
|
|
||||||
|
|
||||||
import sphinx_rtd_theme
|
sys.path.insert(0, os.path.abspath("../../"))
|
||||||
|
|
||||||
# -- Project information -----------------------------------------------------
|
# -- Project information -----------------------------------------------------
|
||||||
|
|
||||||
@ -24,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.3.0-dev0"
|
release = "0.4.0"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
@ -128,15 +127,12 @@ latex_elements = {
|
|||||||
# The paper size ("letterpaper" or "a4paper").
|
# The paper size ("letterpaper" or "a4paper").
|
||||||
#
|
#
|
||||||
# "papersize": "letterpaper",
|
# "papersize": "letterpaper",
|
||||||
|
|
||||||
# The font size ("10pt", "11pt" or "12pt").
|
# The font size ("10pt", "11pt" or "12pt").
|
||||||
#
|
#
|
||||||
# "pointsize": "10pt",
|
# "pointsize": "10pt",
|
||||||
|
|
||||||
# Additional stuff for the LaTeX preamble.
|
# Additional stuff for the LaTeX preamble.
|
||||||
#
|
#
|
||||||
# "preamble": "",
|
# "preamble": "",
|
||||||
|
|
||||||
# Latex figure (float) alignment
|
# Latex figure (float) alignment
|
||||||
#
|
#
|
||||||
# "figure_align": "htbp",
|
# "figure_align": "htbp",
|
||||||
@ -146,15 +142,21 @@ latex_elements = {
|
|||||||
# (source start file, target name, title,
|
# (source start file, target name, title,
|
||||||
# author, documentclass [howto, manual, or own class]).
|
# author, documentclass [howto, manual, or own class]).
|
||||||
latex_documents = [
|
latex_documents = [
|
||||||
(master_doc, "prototorch.tex", "ProtoTorch Documentation",
|
(
|
||||||
"Jensun Ravichandran", "manual"),
|
master_doc,
|
||||||
|
"prototorch.tex",
|
||||||
|
"ProtoTorch Documentation",
|
||||||
|
"Jensun Ravichandran",
|
||||||
|
"manual",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# -- Options for manual page output ---------------------------------------
|
# -- Options for manual page output ---------------------------------------
|
||||||
|
|
||||||
# One entry per manual page. List of tuples
|
# One entry per manual page. List of tuples
|
||||||
# (source start file, name, description, authors, manual section).
|
# (source start file, name, description, authors, manual section).
|
||||||
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]
|
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author],
|
||||||
|
1)]
|
||||||
|
|
||||||
# -- Options for Texinfo output -------------------------------------------
|
# -- Options for Texinfo output -------------------------------------------
|
||||||
|
|
||||||
@ -162,9 +164,15 @@ man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)
|
|||||||
# (source start file, target name, title, author,
|
# (source start file, target name, title, author,
|
||||||
# dir menu entry, description, category)
|
# dir menu entry, description, category)
|
||||||
texinfo_documents = [
|
texinfo_documents = [
|
||||||
(master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
|
(
|
||||||
|
master_doc,
|
||||||
|
"prototorch",
|
||||||
|
"ProtoTorch Documentation",
|
||||||
|
author,
|
||||||
|
"prototorch",
|
||||||
"Prototype-based machine learning in PyTorch.",
|
"Prototype-based machine learning in PyTorch.",
|
||||||
"Miscellaneous"),
|
"Miscellaneous",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Example configuration for intersphinx: refer to the Python standard library.
|
# Example configuration for intersphinx: refer to the Python standard library.
|
||||||
|
@ -3,13 +3,14 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from torchinfo import summary
|
||||||
|
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
from sklearn.datasets import load_iris
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
|
||||||
from torchinfo import summary
|
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
# Prepare and preprocess the data
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
@ -29,7 +30,8 @@ class Model(torch.nn.Module):
|
|||||||
prototypes_per_class=3,
|
prototypes_per_class=3,
|
||||||
nclasses=3,
|
nclasses=3,
|
||||||
prototype_initializer="stratified_random",
|
prototype_initializer="stratified_random",
|
||||||
data=[x_train, y_train])
|
data=[x_train, y_train],
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos = self.proto_layer.prototypes
|
protos = self.proto_layer.prototypes
|
||||||
@ -61,8 +63,10 @@ for epoch in range(70):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pred = wtac(dis, plabels)
|
pred = wtac(dis, plabels)
|
||||||
correct = pred.eq(y_in.view_as(pred)).sum().item()
|
correct = pred.eq(y_in.view_as(pred)).sum().item()
|
||||||
acc = 100. * correct / len(x_train)
|
acc = 100.0 * correct / len(x_train)
|
||||||
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
|
print(
|
||||||
|
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
|
||||||
|
)
|
||||||
|
|
||||||
# Take a gradient descent step
|
# Take a gradient descent step
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@ -83,13 +87,15 @@ for epoch in range(70):
|
|||||||
ax.set_ylabel("Data dimension 2")
|
ax.set_ylabel("Data dimension 2")
|
||||||
cmap = "viridis"
|
cmap = "viridis"
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||||
ax.scatter(protos[:, 0],
|
ax.scatter(
|
||||||
|
protos[:, 0],
|
||||||
protos[:, 1],
|
protos[:, 1],
|
||||||
c=plabels,
|
c=plabels,
|
||||||
cmap=cmap,
|
cmap=cmap,
|
||||||
edgecolor="k",
|
edgecolor="k",
|
||||||
marker="D",
|
marker="D",
|
||||||
s=50)
|
s=50,
|
||||||
|
)
|
||||||
|
|
||||||
# Paint decision regions
|
# Paint decision regions
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
|
@ -20,11 +20,13 @@ class Model(torch.nn.Module):
|
|||||||
"""GMLVQ model as a siamese network."""
|
"""GMLVQ model as a siamese network."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
x, y = train_data.data, train_data.targets
|
x, y = train_data.data, train_data.targets
|
||||||
self.p1 = Prototypes1D(input_dim=100,
|
self.p1 = Prototypes1D(
|
||||||
|
input_dim=100,
|
||||||
prototypes_per_class=2,
|
prototypes_per_class=2,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototype_initializer="stratified_random",
|
prototype_initializer="stratified_random",
|
||||||
data=[x, y])
|
data=[x, y],
|
||||||
|
)
|
||||||
self.omega = torch.nn.Linear(in_features=100,
|
self.omega = torch.nn.Linear(in_features=100,
|
||||||
out_features=100,
|
out_features=100,
|
||||||
bias=False)
|
bias=False)
|
||||||
|
@ -13,8 +13,9 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from prototorch.modules.losses import GLVQLoss
|
|
||||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
from prototorch.functions.helper import calculate_prototype_accuracy
|
||||||
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.models import GTLVQ
|
from prototorch.modules.models import GTLVQ
|
||||||
|
|
||||||
# Parameters and options
|
# Parameters and options
|
||||||
@ -26,32 +27,40 @@ momentum = 0.5
|
|||||||
log_interval = 10
|
log_interval = 10
|
||||||
cuda = "cuda:1"
|
cuda = "cuda:1"
|
||||||
random_seed = 1
|
random_seed = 1
|
||||||
device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
|
device = torch.device(cuda if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# Configures reproducability
|
# Configures reproducability
|
||||||
torch.manual_seed(random_seed)
|
torch.manual_seed(random_seed)
|
||||||
np.random.seed(random_seed)
|
np.random.seed(random_seed)
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
# Prepare and preprocess the data
|
||||||
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
'./files/',
|
torchvision.datasets.MNIST(
|
||||||
|
"./files/",
|
||||||
train=True,
|
train=True,
|
||||||
download=True,
|
download=True,
|
||||||
transform=torchvision.transforms.Compose(
|
transform=torchvision.transforms.Compose([
|
||||||
[transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||||
|
]),
|
||||||
|
),
|
||||||
batch_size=batch_size_train,
|
batch_size=batch_size_train,
|
||||||
shuffle=True)
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
|
test_loader = torch.utils.data.DataLoader(
|
||||||
'./files/',
|
torchvision.datasets.MNIST(
|
||||||
|
"./files/",
|
||||||
train=False,
|
train=False,
|
||||||
download=True,
|
download=True,
|
||||||
transform=torchvision.transforms.Compose(
|
transform=torchvision.transforms.Compose([
|
||||||
[transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.1307, ), (0.3081, ))])),
|
transforms.Normalize((0.1307, ), (0.3081, ))
|
||||||
|
]),
|
||||||
|
),
|
||||||
batch_size=batch_size_test,
|
batch_size=batch_size_test,
|
||||||
shuffle=True)
|
shuffle=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Define the GLVQ model plus appropriate feature extractor
|
# Define the GLVQ model plus appropriate feature extractor
|
||||||
@ -67,25 +76,34 @@ class CNNGTLVQ(torch.nn.Module):
|
|||||||
):
|
):
|
||||||
super(CNNGTLVQ, self).__init__()
|
super(CNNGTLVQ, self).__init__()
|
||||||
|
|
||||||
#Feature Extractor - Simple CNN
|
# Feature Extractor - Simple CNN
|
||||||
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
|
self.fe = nn.Sequential(
|
||||||
nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
|
nn.Conv2d(1, 32, 3, 1),
|
||||||
nn.MaxPool2d(2), nn.Dropout(0.25),
|
nn.ReLU(),
|
||||||
nn.Flatten(), nn.Linear(9216, bottleneck_dim),
|
nn.Conv2d(32, 64, 3, 1),
|
||||||
nn.Dropout(0.5), nn.LeakyReLU(),
|
nn.ReLU(),
|
||||||
nn.LayerNorm(bottleneck_dim))
|
nn.MaxPool2d(2),
|
||||||
|
nn.Dropout(0.25),
|
||||||
|
nn.Flatten(),
|
||||||
|
nn.Linear(9216, bottleneck_dim),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.LeakyReLU(),
|
||||||
|
nn.LayerNorm(bottleneck_dim),
|
||||||
|
)
|
||||||
|
|
||||||
# Forward pass of subspace and prototype initialization data through feature extractor
|
# Forward pass of subspace and prototype initialization data through feature extractor
|
||||||
subspace_data = self.fe(subspace_data)
|
subspace_data = self.fe(subspace_data)
|
||||||
prototype_data[0] = self.fe(prototype_data[0])
|
prototype_data[0] = self.fe(prototype_data[0])
|
||||||
|
|
||||||
# Initialization of GTLVQ
|
# Initialization of GTLVQ
|
||||||
self.gtlvq = GTLVQ(num_classes,
|
self.gtlvq = GTLVQ(
|
||||||
|
num_classes,
|
||||||
subspace_data,
|
subspace_data,
|
||||||
prototype_data,
|
prototype_data,
|
||||||
tangent_projection_type=tangent_projection_type,
|
tangent_projection_type=tangent_projection_type,
|
||||||
feature_dim=bottleneck_dim,
|
feature_dim=bottleneck_dim,
|
||||||
prototypes_per_class=prototypes_per_class)
|
prototypes_per_class=prototypes_per_class,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Feature Extraction
|
# Feature Extraction
|
||||||
@ -103,20 +121,24 @@ subspace_data = torch.cat(
|
|||||||
prototype_data = next(iter(train_loader))
|
prototype_data = next(iter(train_loader))
|
||||||
|
|
||||||
# Build the CNN GTLVQ model
|
# Build the CNN GTLVQ model
|
||||||
model = CNNGTLVQ(10,
|
model = CNNGTLVQ(
|
||||||
|
10,
|
||||||
subspace_data,
|
subspace_data,
|
||||||
prototype_data,
|
prototype_data,
|
||||||
tangent_projection_type="local",
|
tangent_projection_type="local",
|
||||||
bottleneck_dim=128).to(device)
|
bottleneck_dim=128,
|
||||||
|
).to(device)
|
||||||
|
|
||||||
# Optimize using SGD optimizer from `torch.optim`
|
# Optimize using SGD optimizer from `torch.optim`
|
||||||
optimizer = torch.optim.Adam([{
|
optimizer = torch.optim.Adam(
|
||||||
'params': model.fe.parameters()
|
[{
|
||||||
}, {
|
"params": model.fe.parameters()
|
||||||
'params': model.gtlvq.parameters()
|
}, {
|
||||||
}],
|
"params": model.gtlvq.parameters()
|
||||||
lr=learning_rate)
|
}],
|
||||||
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
|
lr=learning_rate,
|
||||||
|
)
|
||||||
|
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
for epoch in range(n_epochs):
|
for epoch in range(n_epochs):
|
||||||
@ -139,8 +161,8 @@ for epoch in range(n_epochs):
|
|||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
||||||
print(
|
print(
|
||||||
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
||||||
Train Acc: {acc.item():02.02f}')
|
Train Acc: {acc.item():02.02f}")
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -154,9 +176,9 @@ for epoch in range(n_epochs):
|
|||||||
i = torch.argmin(test_distances, 1)
|
i = torch.argmin(test_distances, 1)
|
||||||
correct += torch.sum(y_test == test_plabels[i])
|
correct += torch.sum(y_test == test_plabels[i])
|
||||||
total += y_test.size(0)
|
total += y_test.size(0)
|
||||||
print('Accuracy of the network on the test images: %d %%' %
|
print("Accuracy of the network on the test images: %d %%" %
|
||||||
(torch.true_divide(correct, total) * 100))
|
(torch.true_divide(correct, total) * 100))
|
||||||
|
|
||||||
# Save the model
|
# Save the model
|
||||||
PATH = './glvq_mnist_model.pth'
|
PATH = "./glvq_mnist_model.pth"
|
||||||
torch.save(model.state_dict(), PATH)
|
torch.save(model.state_dict(), PATH)
|
||||||
|
@ -22,10 +22,12 @@ class Model(torch.nn.Module):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Local-GMLVQ model."""
|
"""Local-GMLVQ model."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.p1 = Prototypes1D(input_dim=2,
|
self.p1 = Prototypes1D(
|
||||||
|
input_dim=2,
|
||||||
prototype_distribution=[1, 2, 2],
|
prototype_distribution=[1, 2, 2],
|
||||||
prototype_initializer="stratified_random",
|
prototype_initializer="stratified_random",
|
||||||
data=[x_train, y_train])
|
data=[x_train, y_train],
|
||||||
|
)
|
||||||
omegas = torch.zeros(5, 2, 2)
|
omegas = torch.zeros(5, 2, 2)
|
||||||
self.omegas = torch.nn.Parameter(omegas)
|
self.omegas = torch.nn.Parameter(omegas)
|
||||||
eye_(self.omegas)
|
eye_(self.omegas)
|
||||||
@ -76,14 +78,16 @@ for epoch in range(100):
|
|||||||
ax.set_xlabel("Data dimension 1")
|
ax.set_xlabel("Data dimension 1")
|
||||||
ax.set_ylabel("Data dimension 2")
|
ax.set_ylabel("Data dimension 2")
|
||||||
cmap = "viridis"
|
cmap = "viridis"
|
||||||
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
|
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
|
||||||
ax.scatter(protos[:, 0],
|
ax.scatter(
|
||||||
|
protos[:, 0],
|
||||||
protos[:, 1],
|
protos[:, 1],
|
||||||
c=plabels,
|
c=plabels,
|
||||||
cmap=cmap,
|
cmap=cmap,
|
||||||
edgecolor='k',
|
edgecolor="k",
|
||||||
marker='D',
|
marker="D",
|
||||||
s=50)
|
s=50,
|
||||||
|
)
|
||||||
|
|
||||||
# Paint decision regions
|
# Paint decision regions
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
|
65
examples/new_components.py
Normal file
65
examples/new_components.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
"""This example script shows the usage of the new components architecture.
|
||||||
|
|
||||||
|
Serialization/deserialization also works as expected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DATASET
|
||||||
|
import torch
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
scaler = StandardScaler()
|
||||||
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
|
x_train = x_train[:, [0, 2]]
|
||||||
|
scaler.fit(x_train)
|
||||||
|
x_train = scaler.transform(x_train)
|
||||||
|
|
||||||
|
x_train = torch.Tensor(x_train)
|
||||||
|
y_train = torch.Tensor(y_train)
|
||||||
|
num_classes = len(torch.unique(y_train))
|
||||||
|
|
||||||
|
# CREATE NEW COMPONENTS
|
||||||
|
from prototorch.components import *
|
||||||
|
from prototorch.components.initializers import *
|
||||||
|
|
||||||
|
unsupervised = Components(6, SelectionInitializer(x_train))
|
||||||
|
print(unsupervised())
|
||||||
|
|
||||||
|
prototypes = LabeledComponents(
|
||||||
|
(3, 2), StratifiedSelectionInitializer(x_train, y_train))
|
||||||
|
print(prototypes())
|
||||||
|
|
||||||
|
components = ReasoningComponents(
|
||||||
|
(3, 6), StratifiedSelectionInitializer(x_train, y_train))
|
||||||
|
print(components())
|
||||||
|
|
||||||
|
# TEST SERIALIZATION
|
||||||
|
import io
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(unsupervised, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_unsupervised = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(unsupervised.components == serialized_unsupervised.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(prototypes, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_prototypes = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(prototypes.components == serialized_prototypes.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
assert torch.all(prototypes.component_labels == serialized_prototypes.
|
||||||
|
component_labels), "Serialization of Components failed."
|
||||||
|
|
||||||
|
save = io.BytesIO()
|
||||||
|
torch.save(components, save)
|
||||||
|
save.seek(0)
|
||||||
|
serialized_components = torch.load(save)
|
||||||
|
|
||||||
|
assert torch.all(components.components == serialized_components.components
|
||||||
|
), "Serialization of Components failed."
|
||||||
|
assert torch.all(components.reasonings == serialized_components.reasonings
|
||||||
|
), "Serialization of Components failed."
|
@ -1,11 +1,7 @@
|
|||||||
"""ProtoTorch package."""
|
"""ProtoTorch package."""
|
||||||
|
|
||||||
# #############################################
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
# #############################################
|
__version__ = "0.4.0"
|
||||||
__version__ = "0.3.0-dev0"
|
|
||||||
|
|
||||||
from prototorch import datasets, functions, modules
|
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
"datasets",
|
"datasets",
|
||||||
@ -13,10 +9,11 @@ __all_core__ = [
|
|||||||
"modules",
|
"modules",
|
||||||
]
|
]
|
||||||
|
|
||||||
# #############################################
|
from .datasets import *
|
||||||
|
|
||||||
# Plugin Loader
|
# Plugin Loader
|
||||||
# #############################################
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__)
|
__path__ = pkgutil.extend_path(__path__, __name__)
|
||||||
@ -25,7 +22,8 @@ __path__ = pkgutil.extend_path(__path__, __name__)
|
|||||||
def discover_plugins():
|
def discover_plugins():
|
||||||
return {
|
return {
|
||||||
entry_point.name: entry_point.load()
|
entry_point.name: entry_point.load()
|
||||||
for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
|
for entry_point in pkg_resources.iter_entry_points(
|
||||||
|
"prototorch.plugins")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -33,12 +31,10 @@ discovered_plugins = discover_plugins()
|
|||||||
locals().update(discovered_plugins)
|
locals().update(discovered_plugins)
|
||||||
|
|
||||||
# Generate combines __version__ and __all__
|
# Generate combines __version__ and __all__
|
||||||
version_plugins = "\n".join(
|
version_plugins = "\n".join([
|
||||||
[
|
|
||||||
"- " + name + ": v" + plugin.__version__
|
"- " + name + ": v" + plugin.__version__
|
||||||
for name, plugin in discovered_plugins.items()
|
for name, plugin in discovered_plugins.items()
|
||||||
]
|
])
|
||||||
)
|
|
||||||
if version_plugins != "":
|
if version_plugins != "":
|
||||||
version_plugins = "\nPlugins: \n" + version_plugins
|
version_plugins = "\nPlugins: \n" + version_plugins
|
||||||
|
|
||||||
|
2
prototorch/components/__init__.py
Normal file
2
prototorch/components/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from prototorch.components.components import *
|
||||||
|
from prototorch.components.initializers import *
|
134
prototorch/components/components.py
Normal file
134
prototorch/components/components.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
"""ProtoTorch components modules."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from prototorch.components.initializers import (ComponentsInitializer,
|
||||||
|
EqualLabelInitializer,
|
||||||
|
ZeroReasoningsInitializer)
|
||||||
|
from prototorch.functions.initializers import get_initializer
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
|
||||||
|
class Components(torch.nn.Module):
|
||||||
|
"""Components is a set of learnable Tensors."""
|
||||||
|
def __init__(self,
|
||||||
|
number_of_components=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None,
|
||||||
|
dtype=torch.float32):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Ignore all initialization settings if initialized_components is given.
|
||||||
|
if initialized_components is not None:
|
||||||
|
self._components = Parameter(initialized_components)
|
||||||
|
if number_of_components is not None or initializer is not None:
|
||||||
|
wmsg = "Arguments ignored while initializing Components"
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
else:
|
||||||
|
self._initialize_components(number_of_components, initializer)
|
||||||
|
|
||||||
|
def _initialize_components(self, number_of_components, initializer):
|
||||||
|
if not isinstance(initializer, ComponentsInitializer):
|
||||||
|
emsg = f"`initializer` has to be some subtype of " \
|
||||||
|
f"{ComponentsInitializer}. " \
|
||||||
|
f"You have provided: {initializer=} instead."
|
||||||
|
raise TypeError(emsg)
|
||||||
|
self._components = Parameter(
|
||||||
|
initializer.generate(number_of_components))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def components(self):
|
||||||
|
"""Tensor containing the component tensors."""
|
||||||
|
return self._components.detach().cpu()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return self._components
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"components.shape: {tuple(self._components.shape)}"
|
||||||
|
|
||||||
|
|
||||||
|
class LabeledComponents(Components):
|
||||||
|
"""LabeledComponents generate a set of components and a set of labels.
|
||||||
|
|
||||||
|
Every Component has a label assigned.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
labels=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
if initialized_components is not None:
|
||||||
|
super().__init__(initialized_components=initialized_components[0])
|
||||||
|
self._labels = initialized_components[1]
|
||||||
|
else:
|
||||||
|
self._initialize_labels(labels)
|
||||||
|
super().__init__(number_of_components=len(self._labels),
|
||||||
|
initializer=initializer)
|
||||||
|
|
||||||
|
def _initialize_labels(self, labels):
|
||||||
|
if type(labels) == tuple:
|
||||||
|
num_classes, prototypes_per_class = labels
|
||||||
|
labels = EqualLabelInitializer(num_classes, prototypes_per_class)
|
||||||
|
|
||||||
|
self._labels = labels.generate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def component_labels(self):
|
||||||
|
"""Tensor containing the component tensors."""
|
||||||
|
return self._labels.detach().cpu()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return super().forward(), self._labels
|
||||||
|
|
||||||
|
|
||||||
|
class ReasoningComponents(Components):
|
||||||
|
"""ReasoningComponents generate a set of components and a set of reasoning matrices.
|
||||||
|
|
||||||
|
Every Component has a reasoning matrix assigned.
|
||||||
|
|
||||||
|
A reasoning matrix is a Nx2 matrix, where N is the number of Classes. The
|
||||||
|
first element is called positive reasoning :math:`p`, the second negative
|
||||||
|
reasoning :math:`n`. A components can reason in favour (positive) of a
|
||||||
|
class, against (negative) a class or not at all (neutral).
|
||||||
|
|
||||||
|
It holds that :math:`0 \leq n \leq 1`, :math:`0 \leq p \leq 1` and :math:`0
|
||||||
|
\leq n+p \leq 1`. Therefore :math:`n` and :math:`p` are two elements of a
|
||||||
|
three element probability distribution.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
reasonings=None,
|
||||||
|
initializer=None,
|
||||||
|
*,
|
||||||
|
initialized_components=None):
|
||||||
|
if initialized_components is not None:
|
||||||
|
super().__init__(initialized_components=initialized_components[0])
|
||||||
|
self._reasonings = initialized_components[1]
|
||||||
|
else:
|
||||||
|
self._initialize_reasonings(reasonings)
|
||||||
|
super().__init__(number_of_components=len(self._reasonings),
|
||||||
|
initializer=initializer)
|
||||||
|
|
||||||
|
def _initialize_reasonings(self, reasonings):
|
||||||
|
if type(reasonings) == tuple:
|
||||||
|
num_classes, number_of_components = reasonings
|
||||||
|
reasonings = ZeroReasoningsInitializer(num_classes,
|
||||||
|
number_of_components)
|
||||||
|
|
||||||
|
self._reasonings = reasonings.generate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reasonings(self):
|
||||||
|
"""Returns Reasoning Matrix.
|
||||||
|
|
||||||
|
Dimension NxCx2
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self._reasonings.detach().cpu()
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
return super().forward(), self._reasonings
|
172
prototorch/components/initializers.py
Normal file
172
prototorch/components/initializers.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
"""ProtoTroch Initializers."""
|
||||||
|
import warnings
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def parse_init_arg(arg):
|
||||||
|
if isinstance(arg, Dataset):
|
||||||
|
data, labels = next(iter(DataLoader(arg, batch_size=len(arg))))
|
||||||
|
# data = data.view(len(arg), -1) # flatten
|
||||||
|
else:
|
||||||
|
data, labels = arg
|
||||||
|
if not isinstance(data, torch.Tensor):
|
||||||
|
wmsg = f"Converting data to {torch.Tensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
data = torch.Tensor(data)
|
||||||
|
if not isinstance(labels, torch.Tensor):
|
||||||
|
wmsg = f"Converting labels to {torch.Tensor}."
|
||||||
|
warnings.warn(wmsg)
|
||||||
|
labels = torch.Tensor(labels)
|
||||||
|
return data, labels
|
||||||
|
|
||||||
|
|
||||||
|
# Components
|
||||||
|
class ComponentsInitializer(object):
|
||||||
|
def generate(self, number_of_components):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class DimensionAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, c_dims):
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(c_dims, Iterable):
|
||||||
|
self.components_dims = tuple(c_dims)
|
||||||
|
else:
|
||||||
|
self.components_dims = (c_dims, )
|
||||||
|
|
||||||
|
|
||||||
|
class OnesInitializer(DimensionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.ones(gen_dims)
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosInitializer(DimensionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.zeros(gen_dims)
|
||||||
|
|
||||||
|
|
||||||
|
class UniformInitializer(DimensionAwareInitializer):
|
||||||
|
def __init__(self, c_dims, min=0.0, max=1.0):
|
||||||
|
super().__init__(c_dims)
|
||||||
|
|
||||||
|
self.min = min
|
||||||
|
self.max = max
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
gen_dims = (length, ) + self.components_dims
|
||||||
|
return torch.ones(gen_dims).uniform_(self.min, self.max)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, positions):
|
||||||
|
super().__init__()
|
||||||
|
self.data = positions
|
||||||
|
|
||||||
|
|
||||||
|
class SelectionInitializer(PositionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||||
|
return self.data[indices]
|
||||||
|
|
||||||
|
|
||||||
|
class MeanInitializer(PositionAwareInitializer):
|
||||||
|
def generate(self, length):
|
||||||
|
mean = torch.mean(self.data, dim=0)
|
||||||
|
repeat_dim = [length] + [1] * len(mean.shape)
|
||||||
|
return mean.repeat(repeat_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassAwareInitializer(ComponentsInitializer):
|
||||||
|
def __init__(self, arg):
|
||||||
|
super().__init__()
|
||||||
|
data, labels = parse_init_arg(arg)
|
||||||
|
self.data = data
|
||||||
|
self.labels = labels
|
||||||
|
|
||||||
|
self.clabels = torch.unique(self.labels)
|
||||||
|
self.num_classes = len(self.clabels)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedMeanInitializer(ClassAwareInitializer):
|
||||||
|
def __init__(self, arg):
|
||||||
|
super().__init__(arg)
|
||||||
|
|
||||||
|
self.initializers = []
|
||||||
|
for clabel in self.clabels:
|
||||||
|
class_data = self.data[self.labels == clabel]
|
||||||
|
class_initializer = MeanInitializer(class_data)
|
||||||
|
self.initializers.append(class_initializer)
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
per_class = length // self.num_classes
|
||||||
|
samples_list = [init.generate(per_class) for init in self.initializers]
|
||||||
|
return torch.vstack(samples_list)
|
||||||
|
|
||||||
|
|
||||||
|
class StratifiedSelectionInitializer(ClassAwareInitializer):
|
||||||
|
def __init__(self, arg, *, noise=None):
|
||||||
|
super().__init__(arg)
|
||||||
|
self.noise = noise
|
||||||
|
|
||||||
|
self.initializers = []
|
||||||
|
for clabel in self.clabels:
|
||||||
|
class_data = self.data[self.labels == clabel]
|
||||||
|
class_initializer = SelectionInitializer(class_data)
|
||||||
|
self.initializers.append(class_initializer)
|
||||||
|
|
||||||
|
def add_noise(self, x):
|
||||||
|
"""Shifts some dimensions of the data randomly."""
|
||||||
|
n1 = torch.rand_like(x)
|
||||||
|
n2 = torch.rand_like(x)
|
||||||
|
mask = torch.bernoulli(n1) - torch.bernoulli(n2)
|
||||||
|
return x + (self.noise * mask)
|
||||||
|
|
||||||
|
def generate(self, length):
|
||||||
|
per_class = length // self.num_classes
|
||||||
|
samples_list = [init.generate(per_class) for init in self.initializers]
|
||||||
|
samples = torch.vstack(samples_list)
|
||||||
|
if self.noise is not None:
|
||||||
|
# samples = self.add_noise(samples)
|
||||||
|
samples = samples + self.noise
|
||||||
|
return samples
|
||||||
|
|
||||||
|
|
||||||
|
# Labels
|
||||||
|
class LabelsInitializer:
|
||||||
|
def generate(self):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class EqualLabelInitializer(LabelsInitializer):
|
||||||
|
def __init__(self, classes, per_class):
|
||||||
|
self.classes = classes
|
||||||
|
self.per_class = per_class
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
|
||||||
|
|
||||||
|
|
||||||
|
# Reasonings
|
||||||
|
class ReasoningsInitializer:
|
||||||
|
def generate(self, length):
|
||||||
|
raise NotImplementedError("Subclasses should implement this!")
|
||||||
|
|
||||||
|
|
||||||
|
class ZeroReasoningsInitializer(ReasoningsInitializer):
|
||||||
|
def __init__(self, classes, length):
|
||||||
|
self.classes = classes
|
||||||
|
self.length = length
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
return torch.zeros((self.length, self.classes, 2))
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases
|
||||||
|
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
|
||||||
|
SMI = StratifiedMeanInitializer
|
||||||
|
Random = RandomInitializer = UniformInitializer
|
@ -1,7 +1,11 @@
|
|||||||
"""ProtoTorch datasets."""
|
"""ProtoTorch datasets."""
|
||||||
|
|
||||||
|
from .abstract import NumpyDataset
|
||||||
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Tecator',
|
"NumpyDataset",
|
||||||
|
"Spiral",
|
||||||
|
"Tecator",
|
||||||
]
|
]
|
||||||
|
@ -12,6 +12,13 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||||
|
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||||
|
def __init__(self, *arrays):
|
||||||
|
tensors = [torch.Tensor(arr) for arr in arrays]
|
||||||
|
super().__init__(*tensors)
|
||||||
|
|
||||||
|
|
||||||
class Dataset(torch.utils.data.Dataset):
|
class Dataset(torch.utils.data.Dataset):
|
||||||
"""Abstract dataset class to be inherited."""
|
"""Abstract dataset class to be inherited."""
|
||||||
|
|
||||||
@ -44,15 +51,13 @@ class ProtoDataset(Dataset):
|
|||||||
self._download()
|
self._download()
|
||||||
|
|
||||||
if not self._check_exists():
|
if not self._check_exists():
|
||||||
raise RuntimeError(
|
raise RuntimeError("Dataset not found. "
|
||||||
"Dataset not found. " "You can use download=True to download it"
|
"You can use download=True to download it")
|
||||||
)
|
|
||||||
|
|
||||||
data_file = self.training_file if self.train else self.test_file
|
data_file = self.training_file if self.train else self.test_file
|
||||||
|
|
||||||
self.data, self.targets = torch.load(
|
self.data, self.targets = torch.load(
|
||||||
os.path.join(self.processed_folder, data_file)
|
os.path.join(self.processed_folder, data_file))
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def raw_folder(self):
|
def raw_folder(self):
|
||||||
@ -68,8 +73,9 @@ class ProtoDataset(Dataset):
|
|||||||
|
|
||||||
def _check_exists(self):
|
def _check_exists(self):
|
||||||
return os.path.exists(
|
return os.path.exists(
|
||||||
os.path.join(self.processed_folder, self.training_file)
|
os.path.join(
|
||||||
) and os.path.exists(os.path.join(self.processed_folder, self.test_file))
|
self.processed_folder, self.training_file)) and os.path.exists(
|
||||||
|
os.path.join(self.processed_folder, self.test_file))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
head = "Dataset " + self.__class__.__name__
|
head = "Dataset " + self.__class__.__name__
|
||||||
|
33
prototorch/datasets/spiral.py
Normal file
33
prototorch/datasets/spiral.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
"""Spiral dataset for binary classification."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_spiral(n_samples=500, noise=0.3):
|
||||||
|
def get_samples(n, delta_t):
|
||||||
|
points = []
|
||||||
|
for i in range(n):
|
||||||
|
r = i / n_samples * 5
|
||||||
|
t = 1.75 * i / n * 2 * np.pi + delta_t
|
||||||
|
x = r * np.sin(t) + np.random.rand(1) * noise
|
||||||
|
y = r * np.cos(t) + np.random.rand(1) * noise
|
||||||
|
points.append([x, y])
|
||||||
|
return points
|
||||||
|
|
||||||
|
n = n_samples // 2
|
||||||
|
positive = get_samples(n=n, delta_t=0)
|
||||||
|
negative = get_samples(n=n, delta_t=np.pi)
|
||||||
|
x = np.concatenate(
|
||||||
|
[np.array(positive).reshape(n, -1),
|
||||||
|
np.array(negative).reshape(n, -1)],
|
||||||
|
axis=0)
|
||||||
|
y = np.concatenate([np.zeros(n), np.ones(n)])
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class Spiral(torch.utils.data.TensorDataset):
|
||||||
|
"""Spiral dataset for binary classification."""
|
||||||
|
def __init__(self, n_samples=500, noise=0.3):
|
||||||
|
x, y = make_spiral(n_samples, noise)
|
||||||
|
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
@ -52,7 +52,8 @@ class Tecator(ProtoDataset):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_resources = [
|
_resources = [
|
||||||
("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"),
|
("1P9WIYnyxFPh6f1vqAbnKfK8oYmUgyV83",
|
||||||
|
"ba5607c580d0f91bb27dc29d13c2f8df"),
|
||||||
] # (google_storage_id, md5hash)
|
] # (google_storage_id, md5hash)
|
||||||
classes = ["0 - low_fat", "1 - high_fat"]
|
classes = ["0 - low_fat", "1 - high_fat"]
|
||||||
|
|
||||||
@ -74,15 +75,15 @@ class Tecator(ProtoDataset):
|
|||||||
print("Downloading...")
|
print("Downloading...")
|
||||||
for fileid, md5 in self._resources:
|
for fileid, md5 in self._resources:
|
||||||
filename = "tecator.npz"
|
filename = "tecator.npz"
|
||||||
download_file_from_google_drive(
|
download_file_from_google_drive(fileid,
|
||||||
fileid, root=self.raw_folder, filename=filename, md5=md5
|
root=self.raw_folder,
|
||||||
)
|
filename=filename,
|
||||||
|
md5=md5)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print("Processing...")
|
print("Processing...")
|
||||||
with np.load(
|
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
||||||
os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False
|
allow_pickle=False) as f:
|
||||||
) as f:
|
|
||||||
x_train, y_train = f["x_train"], f["y_train"]
|
x_train, y_train = f["x_train"], f["y_train"]
|
||||||
x_test, y_test = f["x_test"], f["y_test"]
|
x_test, y_test = f["x_test"], f["y_test"]
|
||||||
training_set = [
|
training_set = [
|
||||||
@ -94,9 +95,11 @@ class Tecator(ProtoDataset):
|
|||||||
torch.tensor(y_test),
|
torch.tensor(y_test),
|
||||||
]
|
]
|
||||||
|
|
||||||
with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
|
with open(os.path.join(self.processed_folder, self.training_file),
|
||||||
|
"wb") as f:
|
||||||
torch.save(training_set, f)
|
torch.save(training_set, f)
|
||||||
with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
|
with open(os.path.join(self.processed_folder, self.test_file),
|
||||||
|
"wb") as f:
|
||||||
torch.save(test_set, f)
|
torch.save(test_set, f)
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
@ -4,9 +4,9 @@ from .activations import identity, sigmoid_beta, swish_beta
|
|||||||
from .competitions import knnc, wtac
|
from .competitions import knnc, wtac
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'identity',
|
"identity",
|
||||||
'sigmoid_beta',
|
"sigmoid_beta",
|
||||||
'swish_beta',
|
"swish_beta",
|
||||||
'knnc',
|
"knnc",
|
||||||
'wtac',
|
"wtac",
|
||||||
]
|
]
|
||||||
|
@ -16,40 +16,43 @@ def register_activation(function):
|
|||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
# @torch.jit.script
|
||||||
def identity(x, beta=torch.tensor(0)):
|
def identity(x, beta=0.0):
|
||||||
"""Identity activation function.
|
"""Identity activation function.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
:math:`f(x) = x`
|
:math:`f(x) = x`
|
||||||
|
|
||||||
|
Keyword Arguments:
|
||||||
|
beta (`float`): Ignored.
|
||||||
"""
|
"""
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
# @torch.jit.script
|
||||||
def sigmoid_beta(x, beta=torch.tensor(10)):
|
def sigmoid_beta(x, beta=10.0):
|
||||||
r"""Sigmoid activation function with scaling.
|
r"""Sigmoid activation function with scaling.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
|
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
beta (`torch.tensor`): Scaling parameter :math:`\beta`
|
beta (`float`): Scaling parameter :math:`\beta`
|
||||||
"""
|
"""
|
||||||
out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x))
|
out = 1.0 / (1.0 + torch.exp(-1.0 * beta * x))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
# @torch.jit.script
|
# @torch.jit.script
|
||||||
def swish_beta(x, beta=torch.tensor(10)):
|
def swish_beta(x, beta=10.0):
|
||||||
r"""Swish activation function with scaling.
|
r"""Swish activation function with scaling.
|
||||||
|
|
||||||
Definition:
|
Definition:
|
||||||
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
|
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
beta (`torch.tensor`): Scaling parameter :math:`\beta`
|
beta (`float`): Scaling parameter :math:`\beta`
|
||||||
"""
|
"""
|
||||||
out = x * sigmoid_beta(x, beta=beta)
|
out = x * sigmoid_beta(x, beta=beta)
|
||||||
return out
|
return out
|
||||||
@ -61,4 +64,4 @@ def get_activation(funcname):
|
|||||||
return funcname
|
return funcname
|
||||||
if funcname in ACTIVATIONS:
|
if funcname in ACTIVATIONS:
|
||||||
return ACTIVATIONS.get(funcname)
|
return ACTIVATIONS.get(funcname)
|
||||||
raise NameError(f'Activation {funcname} was not found.')
|
raise NameError(f"Activation {funcname} was not found.")
|
||||||
|
@ -12,7 +12,7 @@ def stratified_min(distances, labels):
|
|||||||
return distances
|
return distances
|
||||||
batch_size = distances.size()[0]
|
batch_size = distances.size()[0]
|
||||||
winning_distances = torch.zeros(nclasses, batch_size)
|
winning_distances = torch.zeros(nclasses, batch_size)
|
||||||
inf = torch.full_like(distances.T, fill_value=float('inf'))
|
inf = torch.full_like(distances.T, fill_value=float("inf"))
|
||||||
# distances_to_wpluses = torch.where(matcher, distances, inf)
|
# distances_to_wpluses = torch.where(matcher, distances, inf)
|
||||||
for i, cl in enumerate(clabels):
|
for i, cl in enumerate(clabels):
|
||||||
# cdists = distances.T[labels == cl]
|
# cdists = distances.T[labels == cl]
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
"""ProtoTorch distance functions."""
|
"""ProtoTorch distance functions."""
|
||||||
|
|
||||||
import torch
|
|
||||||
from prototorch.functions.helper import (
|
|
||||||
equal_int_shape,
|
|
||||||
_int_and_mixed_shape,
|
|
||||||
_check_shapes,
|
|
||||||
)
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||||
|
equal_int_shape)
|
||||||
|
|
||||||
|
|
||||||
def squared_euclidean_distance(x, y):
|
def squared_euclidean_distance(x, y):
|
||||||
@ -43,9 +41,21 @@ def euclidean_distance(x, y):
|
|||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
|
def euclidean_distance_v2(x, y):
|
||||||
|
diff = y - x.unsqueeze(1)
|
||||||
|
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||||
|
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||||
|
# batch diagonal. See:
|
||||||
|
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||||
|
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||||
|
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
||||||
|
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
||||||
|
# print(f"{distances.shape=}") # (nx, ny)
|
||||||
|
return distances
|
||||||
|
|
||||||
|
|
||||||
def lpnorm_distance(x, y, p):
|
def lpnorm_distance(x, y, p):
|
||||||
r"""
|
r"""Calculate the lp-norm between :math:`\bm x` and :math:`\bm y`.
|
||||||
Calculates the lp-norm between :math:`\bm x` and :math:`\bm y`.
|
|
||||||
Also known as Minkowski distance.
|
Also known as Minkowski distance.
|
||||||
|
|
||||||
Compute :math:`{\| \bm x - \bm y \|}_p`.
|
Compute :math:`{\| \bm x - \bm y \|}_p`.
|
||||||
@ -88,7 +98,7 @@ def lomega_distance(x, y, omegas):
|
|||||||
projected_y = torch.diagonal(y @ omegas).T
|
projected_y = torch.diagonal(y @ omegas).T
|
||||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||||
batchwise_difference = expanded_y - projected_x
|
batchwise_difference = expanded_y - projected_x
|
||||||
differences_squared = batchwise_difference ** 2
|
differences_squared = batchwise_difference**2
|
||||||
distances = torch.sum(differences_squared, dim=2)
|
distances = torch.sum(differences_squared, dim=2)
|
||||||
distances = distances.permute(1, 0)
|
distances = distances.permute(1, 0)
|
||||||
return distances
|
return distances
|
||||||
@ -107,26 +117,18 @@ def euclidean_distance_matrix(x, y, squared=False, epsilon=1e-10):
|
|||||||
for tensor in [x, y]:
|
for tensor in [x, y]:
|
||||||
if tensor.ndim != 2:
|
if tensor.ndim != 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The tensor dimension must be two. You provide: tensor.ndim="
|
"The tensor dimension must be two. You provide: tensor.ndim=" +
|
||||||
+ str(tensor.ndim)
|
str(tensor.ndim) + ".")
|
||||||
+ "."
|
|
||||||
)
|
|
||||||
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
|
if not equal_int_shape([tuple(x.shape)[1]], [tuple(y.shape)[1]]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
|
"The vector shape must be equivalent in both tensors. You provide: tuple(y.shape)[1]="
|
||||||
+ str(tuple(x.shape)[1])
|
+ str(tuple(x.shape)[1]) + " and tuple(y.shape)(y)[1]=" +
|
||||||
+ " and tuple(y.shape)(y)[1]="
|
str(tuple(y.shape)[1]) + ".")
|
||||||
+ str(tuple(y.shape)[1])
|
|
||||||
+ "."
|
|
||||||
)
|
|
||||||
|
|
||||||
y = torch.transpose(y)
|
y = torch.transpose(y)
|
||||||
|
|
||||||
diss = (
|
diss = (torch.sum(x**2, axis=1, keepdims=True) - 2 * torch.dot(x, y) +
|
||||||
torch.sum(x ** 2, axis=1, keepdims=True)
|
torch.sum(y**2, axis=0, keepdims=True))
|
||||||
- 2 * torch.dot(x, y)
|
|
||||||
+ torch.sum(y ** 2, axis=0, keepdims=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not squared:
|
if not squared:
|
||||||
if epsilon == 0:
|
if epsilon == 0:
|
||||||
@ -173,19 +175,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
if subspaces.ndim == 2:
|
if subspaces.ndim == 2:
|
||||||
# clean solution without map if the matrix_scope is global
|
# clean solution without map if the matrix_scope is global
|
||||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
|
projectors = torch.eye(subspace_int_shape[-2]) - torch.dot(
|
||||||
subspaces, torch.transpose(subspaces)
|
subspaces, torch.transpose(subspaces))
|
||||||
)
|
|
||||||
|
|
||||||
projected_signals = torch.dot(signals, projectors)
|
projected_signals = torch.dot(signals, projectors)
|
||||||
projected_protos = torch.dot(protos, projectors)
|
projected_protos = torch.dot(protos, projectors)
|
||||||
|
|
||||||
diss = euclidean_distance_matrix(
|
diss = euclidean_distance_matrix(projected_signals,
|
||||||
projected_signals, projected_protos, squared=squared, epsilon=epsilon
|
projected_protos,
|
||||||
)
|
squared=squared,
|
||||||
|
epsilon=epsilon)
|
||||||
|
|
||||||
diss = torch.reshape(
|
diss = torch.reshape(
|
||||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
|
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||||
)
|
|
||||||
|
|
||||||
return torch.permute(diss, [0, 2, 1])
|
return torch.permute(diss, [0, 2, 1])
|
||||||
|
|
||||||
@ -193,21 +194,18 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
|
|
||||||
# no solution without map possible --> memory efficient but slow!
|
# no solution without map possible --> memory efficient but slow!
|
||||||
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
|
projectors = torch.eye(subspace_int_shape[-2]) - torch.bmm(
|
||||||
subspaces, subspaces
|
subspaces,
|
||||||
) # K.batch_dot(subspaces, subspaces, [2, 2])
|
subspaces) # K.batch_dot(subspaces, subspaces, [2, 2])
|
||||||
|
|
||||||
projected_protos = (
|
projected_protos = (protos @ subspaces
|
||||||
protos @ subspaces
|
|
||||||
).T # K.batch_dot(projectors, protos, [1, 1]))
|
).T # K.batch_dot(projectors, protos, [1, 1]))
|
||||||
|
|
||||||
def projected_norm(projector):
|
def projected_norm(projector):
|
||||||
return torch.sum(torch.dot(signals, projector) ** 2, axis=1)
|
return torch.sum(torch.dot(signals, projector)**2, axis=1)
|
||||||
|
|
||||||
diss = (
|
diss = (torch.transpose(map(projected_norm, projectors)) -
|
||||||
torch.transpose(map(projected_norm, projectors))
|
2 * torch.dot(signals, projected_protos) +
|
||||||
- 2 * torch.dot(signals, projected_protos)
|
torch.sum(projected_protos**2, axis=0, keepdims=True))
|
||||||
+ torch.sum(projected_protos ** 2, axis=0, keepdims=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not squared:
|
if not squared:
|
||||||
if epsilon == 0:
|
if epsilon == 0:
|
||||||
@ -216,8 +214,7 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
diss = torch.sqrt(torch.max(diss, epsilon))
|
diss = torch.sqrt(torch.max(diss, epsilon))
|
||||||
|
|
||||||
diss = torch.reshape(
|
diss = torch.reshape(
|
||||||
diss, [signal_shape[0], signal_shape[2], proto_shape[0]]
|
diss, [signal_shape[0], signal_shape[2], proto_shape[0]])
|
||||||
)
|
|
||||||
|
|
||||||
return torch.permute(diss, [0, 2, 1])
|
return torch.permute(diss, [0, 2, 1])
|
||||||
|
|
||||||
@ -233,12 +230,12 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
|
|
||||||
# Scope: Tangentspace Projections
|
# Scope: Tangentspace Projections
|
||||||
diff = torch.reshape(
|
diff = torch.reshape(
|
||||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
|
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||||
)
|
|
||||||
projected_diff = diff @ projectors
|
projected_diff = diff @ projectors
|
||||||
projected_diff = torch.reshape(
|
projected_diff = torch.reshape(
|
||||||
projected_diff,
|
projected_diff,
|
||||||
(signal_shape[0], signal_shape[2], signal_shape[1]) + signal_shape[3:],
|
(signal_shape[0], signal_shape[2], signal_shape[1]) +
|
||||||
|
signal_shape[3:],
|
||||||
)
|
)
|
||||||
|
|
||||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||||
@ -251,13 +248,13 @@ def tangent_distance(signals, protos, subspaces, squared=False, epsilon=1e-10):
|
|||||||
|
|
||||||
# Scope: Tangentspace Projections
|
# Scope: Tangentspace Projections
|
||||||
diff = torch.reshape(
|
diff = torch.reshape(
|
||||||
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1)
|
diff, (signal_shape[0] * signal_shape[2], signal_shape[1], -1))
|
||||||
)
|
|
||||||
diff = diff.permute([1, 0, 2])
|
diff = diff.permute([1, 0, 2])
|
||||||
projected_diff = torch.bmm(diff, projectors)
|
projected_diff = torch.bmm(diff, projectors)
|
||||||
projected_diff = torch.reshape(
|
projected_diff = torch.reshape(
|
||||||
projected_diff,
|
projected_diff,
|
||||||
(signal_shape[1], signal_shape[0], signal_shape[2]) + signal_shape[3:],
|
(signal_shape[1], signal_shape[0], signal_shape[2]) +
|
||||||
|
signal_shape[3:],
|
||||||
)
|
)
|
||||||
|
|
||||||
diss = torch.norm(projected_diff, 2, dim=-1)
|
diss = torch.norm(projected_diff, 2, dim=-1)
|
||||||
|
@ -23,7 +23,7 @@ def predict_label(y_pred, plabels):
|
|||||||
|
|
||||||
def mixed_shape(inputs):
|
def mixed_shape(inputs):
|
||||||
if not torch.is_tensor(inputs):
|
if not torch.is_tensor(inputs):
|
||||||
raise ValueError('Input must be a tensor.')
|
raise ValueError("Input must be a tensor.")
|
||||||
else:
|
else:
|
||||||
int_shape = list(inputs.shape)
|
int_shape = list(inputs.shape)
|
||||||
# sometimes int_shape returns mixed integer types
|
# sometimes int_shape returns mixed integer types
|
||||||
@ -39,11 +39,11 @@ def mixed_shape(inputs):
|
|||||||
def equal_int_shape(shape_1, shape_2):
|
def equal_int_shape(shape_1, shape_2):
|
||||||
if not isinstance(shape_1,
|
if not isinstance(shape_1,
|
||||||
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
|
(tuple, list)) or not isinstance(shape_2, (tuple, list)):
|
||||||
raise ValueError('Input shapes must list or tuple.')
|
raise ValueError("Input shapes must list or tuple.")
|
||||||
for shape in [shape_1, shape_2]:
|
for shape in [shape_1, shape_2]:
|
||||||
if not all([isinstance(x, int) or x is None for x in shape]):
|
if not all([isinstance(x, int) or x is None for x in shape]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Input shapes must be list or tuple of int and None values.')
|
"Input shapes must be list or tuple of int and None values.")
|
||||||
|
|
||||||
if len(shape_1) != len(shape_2):
|
if len(shape_1) != len(shape_2):
|
||||||
return False
|
return False
|
||||||
|
@ -104,4 +104,4 @@ def get_initializer(funcname):
|
|||||||
return funcname
|
return funcname
|
||||||
if funcname in INITIALIZERS:
|
if funcname in INITIALIZERS:
|
||||||
return INITIALIZERS.get(funcname)
|
return INITIALIZERS.get(funcname)
|
||||||
raise NameError(f'Initializer {funcname} was not found.')
|
raise NameError(f"Initializer {funcname} was not found.")
|
||||||
|
@ -3,15 +3,22 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _get_dp_dm(distances, targets, plabels):
|
def _get_matcher(targets, labels):
|
||||||
matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
|
"""Returns a boolean tensor."""
|
||||||
if plabels.ndim == 2:
|
matcher = torch.eq(targets.unsqueeze(dim=1), labels)
|
||||||
|
if labels.ndim == 2:
|
||||||
# if the labels are one-hot vectors
|
# if the labels are one-hot vectors
|
||||||
nclasses = targets.size()[1]
|
nclasses = targets.size()[1]
|
||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||||
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dp_dm(distances, targets, plabels):
|
||||||
|
"""Returns the d+ and d- values for a batch of distances."""
|
||||||
|
matcher = _get_matcher(targets, plabels)
|
||||||
not_matcher = torch.bitwise_not(matcher)
|
not_matcher = torch.bitwise_not(matcher)
|
||||||
|
|
||||||
inf = torch.full_like(distances, fill_value=float('inf'))
|
inf = torch.full_like(distances, fill_value=float("inf"))
|
||||||
d_matching = torch.where(matcher, distances, inf)
|
d_matching = torch.where(matcher, distances, inf)
|
||||||
d_unmatching = torch.where(not_matcher, distances, inf)
|
d_unmatching = torch.where(not_matcher, distances, inf)
|
||||||
dp = torch.min(d_matching, dim=1, keepdim=True).values
|
dp = torch.min(d_matching, dim=1, keepdim=True).values
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from __future__ import print_function
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
18
prototorch/functions/similarities.py
Normal file
18
prototorch/functions/similarities.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
"""ProtoTorch similarity functions."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(x, y):
|
||||||
|
"""Compute the cosine similarity between :math:`x` and :math:`y`.
|
||||||
|
|
||||||
|
Expected dimension of x is 2.
|
||||||
|
Expected dimension of y is 2.
|
||||||
|
"""
|
||||||
|
norm_x = x.pow(2).sum(1).sqrt()
|
||||||
|
norm_y = y.pow(2).sum(1).sqrt()
|
||||||
|
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
||||||
|
epsilon = torch.finfo(norm_mat.dtype).eps
|
||||||
|
norm_mat.clamp_(min=epsilon)
|
||||||
|
similarities = (x @ y.T) / norm_mat
|
||||||
|
return similarities
|
@ -3,5 +3,5 @@
|
|||||||
from .prototypes import Prototypes1D
|
from .prototypes import Prototypes1D
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Prototypes1D',
|
"Prototypes1D",
|
||||||
]
|
]
|
||||||
|
@ -7,7 +7,7 @@ from prototorch.functions.losses import glvq_loss
|
|||||||
|
|
||||||
|
|
||||||
class GLVQLoss(torch.nn.Module):
|
class GLVQLoss(torch.nn.Module):
|
||||||
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
|
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.squashing = get_activation(squashing)
|
self.squashing = get_activation(squashing)
|
||||||
@ -18,3 +18,23 @@ class GLVQLoss(torch.nn.Module):
|
|||||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||||
return torch.sum(batch_loss, dim=0)
|
return torch.sum(batch_loss, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class NeuralGasEnergy(torch.nn.Module):
|
||||||
|
def __init__(self, lm):
|
||||||
|
super().__init__()
|
||||||
|
self.lm = lm
|
||||||
|
|
||||||
|
def forward(self, d):
|
||||||
|
order = torch.argsort(d, dim=1)
|
||||||
|
ranks = torch.argsort(order, dim=1)
|
||||||
|
cost = torch.sum(self._nghood_fn(ranks, self.lm) * d)
|
||||||
|
|
||||||
|
return cost, order
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"lambda: {self.lm}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _nghood_fn(rankings, lm):
|
||||||
|
return torch.exp(-rankings / lm)
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
from torch import nn
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.modules.prototypes import Prototypes1D
|
from torch import nn
|
||||||
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
|
|
||||||
from prototorch.functions.normalization import orthogonalization
|
from prototorch.functions.distances import (euclidean_distance_matrix,
|
||||||
|
tangent_distance)
|
||||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
||||||
|
from prototorch.functions.normalization import orthogonalization
|
||||||
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
|
|
||||||
|
|
||||||
class GTLVQ(nn.Module):
|
class GTLVQ(nn.Module):
|
||||||
@ -71,7 +73,7 @@ class GTLVQ(nn.Module):
|
|||||||
subspace_data=None,
|
subspace_data=None,
|
||||||
prototype_data=None,
|
prototype_data=None,
|
||||||
subspace_size=256,
|
subspace_size=256,
|
||||||
tangent_projection_type='local',
|
tangent_projection_type="local",
|
||||||
prototypes_per_class=2,
|
prototypes_per_class=2,
|
||||||
feature_dim=256,
|
feature_dim=256,
|
||||||
):
|
):
|
||||||
@ -82,37 +84,39 @@ class GTLVQ(nn.Module):
|
|||||||
self.feature_dim = feature_dim
|
self.feature_dim = feature_dim
|
||||||
|
|
||||||
if subspace_data is None:
|
if subspace_data is None:
|
||||||
raise ValueError('Init Data must be specified!')
|
raise ValueError("Init Data must be specified!")
|
||||||
|
|
||||||
self.tpt = tangent_projection_type
|
self.tpt = tangent_projection_type
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.tpt == 'local' or self.tpt == 'local_proj':
|
if self.tpt == "local" or self.tpt == "local_proj":
|
||||||
self.init_local_subspace(subspace_data)
|
self.init_local_subspace(subspace_data)
|
||||||
elif self.tpt == 'global':
|
elif self.tpt == "global":
|
||||||
self.init_gobal_subspace(subspace_data, subspace_size)
|
self.init_gobal_subspace(subspace_data, subspace_size)
|
||||||
else:
|
else:
|
||||||
self.subspaces = None
|
self.subspaces = None
|
||||||
|
|
||||||
# Hypothesis-Margin-Classifier
|
# Hypothesis-Margin-Classifier
|
||||||
self.cls = Prototypes1D(input_dim=feature_dim,
|
self.cls = Prototypes1D(
|
||||||
|
input_dim=feature_dim,
|
||||||
prototypes_per_class=prototypes_per_class,
|
prototypes_per_class=prototypes_per_class,
|
||||||
nclasses=num_classes,
|
nclasses=num_classes,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=prototype_data)
|
data=prototype_data,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Tangent Projection
|
# Tangent Projection
|
||||||
if self.tpt == 'local_proj':
|
if self.tpt == "local_proj":
|
||||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2))
|
||||||
dis, proj_x = self.local_tangent_projection(x_conform)
|
dis, proj_x = self.local_tangent_projection(x_conform)
|
||||||
|
|
||||||
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
|
proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
|
||||||
self.feature_dim)
|
self.feature_dim)
|
||||||
return proj_x, dis
|
return proj_x, dis
|
||||||
elif self.tpt == "local":
|
elif self.tpt == "local":
|
||||||
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
|
x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
|
||||||
1).unsqueeze(2)
|
1).unsqueeze(2))
|
||||||
dis = tangent_distance(x_conform, self.cls.prototypes,
|
dis = tangent_distance(x_conform, self.cls.prototypes,
|
||||||
self.subspaces)
|
self.subspaces)
|
||||||
elif self.tpt == "gloabl":
|
elif self.tpt == "gloabl":
|
||||||
@ -127,25 +131,27 @@ class GTLVQ(nn.Module):
|
|||||||
_, _, v = torch.svd(data)
|
_, _, v = torch.svd(data)
|
||||||
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||||
subspaces = subspace[:, :num_subspaces]
|
subspaces = subspace[:, :num_subspaces]
|
||||||
self.subspaces = torch.nn.Parameter(
|
self.subspaces = (torch.nn.Parameter(
|
||||||
subspaces).clone().detach().requires_grad_(True)
|
subspaces).clone().detach().requires_grad_(True))
|
||||||
|
|
||||||
def init_local_subspace(self, data):
|
def init_local_subspace(self, data):
|
||||||
_, _, v = torch.svd(data)
|
_, _, v = torch.svd(data)
|
||||||
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
|
||||||
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
|
subspaces = inital_projector.unsqueeze(0).repeat_interleave(
|
||||||
self.num_protos, 0)
|
self.num_protos, 0)
|
||||||
self.subspaces = torch.nn.Parameter(
|
self.subspaces = (torch.nn.Parameter(
|
||||||
subspaces).clone().detach().requires_grad_(True)
|
subspaces).clone().detach().requires_grad_(True))
|
||||||
|
|
||||||
def global_tangent_distances(self, x):
|
def global_tangent_distances(self, x):
|
||||||
# Tangent Projection
|
# Tangent Projection
|
||||||
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
|
x, projected_prototypes = (
|
||||||
|
x @ self.subspaces,
|
||||||
|
self.cls.prototypes @ self.subspaces,
|
||||||
|
)
|
||||||
# Euclidean Distance
|
# Euclidean Distance
|
||||||
return euclidean_distance_matrix(x, projected_prototypes)
|
return euclidean_distance_matrix(x, projected_prototypes)
|
||||||
|
|
||||||
def local_tangent_projection(self,
|
def local_tangent_projection(self, signals):
|
||||||
signals):
|
|
||||||
# Note: subspaces is always assumed as transposed and must be orthogonal!
|
# Note: subspaces is always assumed as transposed and must be orthogonal!
|
||||||
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
|
||||||
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
# shape(protos): proto_number x dim1 x dim2 x ... x dimN
|
||||||
@ -183,8 +189,7 @@ class GTLVQ(nn.Module):
|
|||||||
def orthogonalize_subspace(self):
|
def orthogonalize_subspace(self):
|
||||||
if self.subspaces is not None:
|
if self.subspaces is not None:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
ortho_subpsaces = orthogonalization(
|
ortho_subpsaces = (orthogonalization(self.subspaces)
|
||||||
self.subspaces
|
if self.tpt == "global" else
|
||||||
) if self.tpt == 'global' else torch.nn.init.orthogonal_(
|
torch.nn.init.orthogonal_(self.subspaces))
|
||||||
self.subspaces)
|
|
||||||
self.subspaces.copy_(ortho_subpsaces)
|
self.subspaces.copy_(ortho_subpsaces)
|
||||||
|
@ -29,14 +29,19 @@ class Prototypes1D(_Prototypes):
|
|||||||
|
|
||||||
TODO Complete this doc-string.
|
TODO Complete this doc-string.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="ones",
|
prototype_initializer="ones",
|
||||||
prototype_distribution=None,
|
prototype_distribution=None,
|
||||||
data=None,
|
data=None,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
one_hot_labels=False,
|
one_hot_labels=False,
|
||||||
**kwargs):
|
**kwargs,
|
||||||
|
):
|
||||||
|
warnings.warn(
|
||||||
|
PendingDeprecationWarning(
|
||||||
|
"Prototypes1D will be replaced in future versions."))
|
||||||
|
|
||||||
# Convert tensors to python lists before processing
|
# Convert tensors to python lists before processing
|
||||||
if prototype_distribution is not None:
|
if prototype_distribution is not None:
|
||||||
|
@ -1 +0,0 @@
|
|||||||
from .colors import color_scheme, get_legend_handles
|
|
46
prototorch/utils/celluloid.py
Normal file
46
prototorch/utils/celluloid.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from matplotlib.animation import ArtistAnimation
|
||||||
|
from matplotlib.artist import Artist
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
|
__version__ = "0.2.0"
|
||||||
|
|
||||||
|
|
||||||
|
class Camera:
|
||||||
|
"""Make animations easier."""
|
||||||
|
def __init__(self, figure: Figure) -> None:
|
||||||
|
"""Create camera from matplotlib figure."""
|
||||||
|
self._figure = figure
|
||||||
|
# need to keep track off artists for each axis
|
||||||
|
self._offsets: Dict[str, Dict[int, int]] = {
|
||||||
|
k: defaultdict(int)
|
||||||
|
for k in
|
||||||
|
["collections", "patches", "lines", "texts", "artists", "images"]
|
||||||
|
}
|
||||||
|
self._photos: List[List[Artist]] = []
|
||||||
|
|
||||||
|
def snap(self) -> List[Artist]:
|
||||||
|
"""Capture current state of the figure."""
|
||||||
|
frame_artists: List[Artist] = []
|
||||||
|
for i, axis in enumerate(self._figure.axes):
|
||||||
|
if axis.legend_ is not None:
|
||||||
|
axis.add_artist(axis.legend_)
|
||||||
|
for name in self._offsets:
|
||||||
|
new_artists = getattr(axis, name)[self._offsets[name][i]:]
|
||||||
|
frame_artists += new_artists
|
||||||
|
self._offsets[name][i] += len(new_artists)
|
||||||
|
self._photos.append(frame_artists)
|
||||||
|
return frame_artists
|
||||||
|
|
||||||
|
def animate(self, *args, **kwargs) -> ArtistAnimation:
|
||||||
|
"""Animate the snapshots taken.
|
||||||
|
Uses matplotlib.animation.ArtistAnimation
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ArtistAnimation
|
||||||
|
"""
|
||||||
|
return ArtistAnimation(self._figure, self._photos, *args, **kwargs)
|
@ -1,13 +1,14 @@
|
|||||||
"""ProtoFlow color utilities."""
|
"""ProtoFlow color utilities."""
|
||||||
|
|
||||||
from matplotlib import cm
|
|
||||||
from matplotlib.colors import Normalize
|
|
||||||
from matplotlib.colors import to_hex
|
|
||||||
from matplotlib.colors import to_rgb
|
|
||||||
import matplotlib.lines as mlines
|
import matplotlib.lines as mlines
|
||||||
|
from matplotlib import cm
|
||||||
|
from matplotlib.colors import Normalize, to_hex, to_rgb
|
||||||
|
|
||||||
|
|
||||||
def color_scheme(n, cmap="viridis", form="hex", tikz=False,
|
def color_scheme(n,
|
||||||
|
cmap="viridis",
|
||||||
|
form="hex",
|
||||||
|
tikz=False,
|
||||||
zero_indexed=False):
|
zero_indexed=False):
|
||||||
"""Return *n* colors from the color scheme.
|
"""Return *n* colors from the color scheme.
|
||||||
|
|
||||||
@ -57,13 +58,16 @@ def get_legend_handles(labels, marker="dots", zero_indexed=False):
|
|||||||
zero_indexed=zero_indexed)
|
zero_indexed=zero_indexed)
|
||||||
for label, color in zip(labels, colors.values()):
|
for label, color in zip(labels, colors.values()):
|
||||||
if marker == "dots":
|
if marker == "dots":
|
||||||
handle = mlines.Line2D([], [],
|
handle = mlines.Line2D(
|
||||||
|
[],
|
||||||
|
[],
|
||||||
color="white",
|
color="white",
|
||||||
markerfacecolor=color,
|
markerfacecolor=color,
|
||||||
marker="o",
|
marker="o",
|
||||||
markersize=10,
|
markersize=10,
|
||||||
markeredgecolor="k",
|
markeredgecolor="k",
|
||||||
label=label)
|
label=label,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
handle = mlines.Line2D([], [],
|
handle = mlines.Line2D([], [],
|
||||||
color=color,
|
color=color,
|
||||||
|
243
prototorch/utils/utils.py
Normal file
243
prototorch/utils/utils.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
"""Utilities that provide various small functionalities."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def progressbar(title, value, end, bar_width=20):
|
||||||
|
percent = float(value) / end
|
||||||
|
arrow = "=" * int(round(percent * bar_width) - 1) + ">"
|
||||||
|
spaces = "." * (bar_width - len(arrow))
|
||||||
|
sys.stdout.write("\r{}: [{}] {}%".format(title, arrow + spaces,
|
||||||
|
int(round(percent * 100))))
|
||||||
|
sys.stdout.flush()
|
||||||
|
if percent == 1.0:
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def prettify_string(inputs, start="", sep=" ", end="\n"):
|
||||||
|
outputs = start + " ".join(inputs.split()) + end
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def pretty_print(inputs):
|
||||||
|
print(prettify_string(inputs))
|
||||||
|
|
||||||
|
|
||||||
|
def writelog(self, *logs, logdir="./logs", logfile="run.txt"):
|
||||||
|
f = os.path.join(logdir, logfile)
|
||||||
|
with open(f, "a+") as fh:
|
||||||
|
for log in logs:
|
||||||
|
fh.write(log)
|
||||||
|
fh.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def start_tensorboard(self, logdir="./logs"):
|
||||||
|
cmd = f"tensorboard --logdir={logdir} --port=6006"
|
||||||
|
os.system(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def make_directory(save_dir):
|
||||||
|
if not os.path.exists(save_dir):
|
||||||
|
print(f"Making directory {save_dir}.")
|
||||||
|
os.mkdir(save_dir)
|
||||||
|
|
||||||
|
|
||||||
|
def make_gif(filenames, duration, output_file=None):
|
||||||
|
try:
|
||||||
|
import imageio
|
||||||
|
except ModuleNotFoundError as e:
|
||||||
|
print("Please install Protoflow with [other] extra requirements.")
|
||||||
|
raise (e)
|
||||||
|
|
||||||
|
images = list()
|
||||||
|
for filename in filenames:
|
||||||
|
images.append(imageio.imread(filename))
|
||||||
|
if not output_file:
|
||||||
|
output_file = f"makegif.gif"
|
||||||
|
if images:
|
||||||
|
imageio.mimwrite(output_file, images, duration=duration)
|
||||||
|
|
||||||
|
|
||||||
|
def gif_from_dir(directory,
|
||||||
|
duration,
|
||||||
|
prefix="",
|
||||||
|
output_file=None,
|
||||||
|
verbose=True):
|
||||||
|
images = os.listdir(directory)
|
||||||
|
if verbose:
|
||||||
|
print(f"Making gif from {len(images)} images under {directory}.")
|
||||||
|
filenames = list()
|
||||||
|
# Sort images
|
||||||
|
images = sorted(
|
||||||
|
images,
|
||||||
|
key=lambda img: int(os.path.splitext(img)[0].replace(prefix, "")))
|
||||||
|
for image in images:
|
||||||
|
fname = os.path.join(directory, image)
|
||||||
|
filenames.append(fname)
|
||||||
|
if not output_file:
|
||||||
|
output_file = os.path.join(directory, "makegif.gif")
|
||||||
|
make_gif(filenames=filenames, duration=duration, output_file=output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def accuracy_score(y_true, y_pred):
|
||||||
|
accuracy = np.sum(y_true == y_pred)
|
||||||
|
normalized_acc = accuracy / float(len(y_true))
|
||||||
|
return normalized_acc
|
||||||
|
|
||||||
|
|
||||||
|
def predict_and_score(clf,
|
||||||
|
x_test,
|
||||||
|
y_test,
|
||||||
|
verbose=False,
|
||||||
|
title="Test accuracy"):
|
||||||
|
y_pred = clf.predict(x_test)
|
||||||
|
accuracy = np.sum(y_test == y_pred)
|
||||||
|
normalized_acc = accuracy / float(len(y_test))
|
||||||
|
if verbose:
|
||||||
|
print(f"{title}: {normalized_acc * 100:06.04f}%")
|
||||||
|
return normalized_acc
|
||||||
|
|
||||||
|
|
||||||
|
def remove_nan_rows(arr):
|
||||||
|
"""Remove all rows with `nan` values in `arr`."""
|
||||||
|
mask = np.isnan(arr).any(axis=1)
|
||||||
|
return arr[~mask]
|
||||||
|
|
||||||
|
|
||||||
|
def remove_nan_cols(arr):
|
||||||
|
"""Remove all columns with `nan` values in `arr`."""
|
||||||
|
mask = np.isnan(arr).any(axis=0)
|
||||||
|
return arr[~mask]
|
||||||
|
|
||||||
|
|
||||||
|
def replace_in(arr, replacement_dict, inplace=False):
|
||||||
|
"""Replace the keys found in `arr` with the values from
|
||||||
|
the `replacement_dict`.
|
||||||
|
"""
|
||||||
|
if inplace:
|
||||||
|
new_arr = arr
|
||||||
|
else:
|
||||||
|
import copy
|
||||||
|
|
||||||
|
new_arr = copy.deepcopy(arr)
|
||||||
|
for k, v in replacement_dict.items():
|
||||||
|
new_arr[arr == k] = v
|
||||||
|
return new_arr
|
||||||
|
|
||||||
|
|
||||||
|
def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
|
||||||
|
"""Split a classification dataset in such a way so as to
|
||||||
|
preserve the class distribution in subsamples of the dataset.
|
||||||
|
"""
|
||||||
|
if train + val > 1.0:
|
||||||
|
raise ValueError("Invalid split values for train and val.")
|
||||||
|
Y = data[:, -1]
|
||||||
|
labels = set(Y)
|
||||||
|
hist = dict()
|
||||||
|
for l in labels:
|
||||||
|
data_l = data[Y == l]
|
||||||
|
nl = len(data_l)
|
||||||
|
nl_train = int(nl * train)
|
||||||
|
nl_val = int(nl * val)
|
||||||
|
nl_test = nl - (nl_train + nl_val)
|
||||||
|
hist[l] = (nl_train, nl_val, nl_test)
|
||||||
|
|
||||||
|
train_data = list()
|
||||||
|
val_data = list()
|
||||||
|
test_data = list()
|
||||||
|
for l, (nl_train, nl_val, nl_test) in hist.items():
|
||||||
|
data_l = data[Y == l]
|
||||||
|
if shuffle:
|
||||||
|
np.random.shuffle(data_l)
|
||||||
|
train_l = data_l[:nl_train]
|
||||||
|
val_l = data_l[nl_train:nl_train + nl_val]
|
||||||
|
test_l = data_l[nl_train + nl_val:nl_train + nl_val + nl_test]
|
||||||
|
train_data.append(train_l)
|
||||||
|
val_data.append(val_l)
|
||||||
|
test_data.append(test_l)
|
||||||
|
|
||||||
|
def _squash(data_list):
|
||||||
|
data = np.array(data_list[0])
|
||||||
|
for item in data_list[1:]:
|
||||||
|
data = np.vstack((data, np.array(item)))
|
||||||
|
return data
|
||||||
|
|
||||||
|
train_data = _squash(train_data)
|
||||||
|
if val_data:
|
||||||
|
val_data = _squash(val_data)
|
||||||
|
if test_data:
|
||||||
|
test_data = _squash(test_data)
|
||||||
|
if return_xy:
|
||||||
|
x_train = train_data[:, :-1]
|
||||||
|
y_train = train_data[:, -1]
|
||||||
|
x_val = val_data[:, :-1]
|
||||||
|
y_val = val_data[:, -1]
|
||||||
|
x_test = test_data[:, :-1]
|
||||||
|
y_test = test_data[:, -1]
|
||||||
|
return (x_train, y_train), (x_val, y_val), (x_test, y_test)
|
||||||
|
return train_data, val_data, test_data
|
||||||
|
|
||||||
|
|
||||||
|
def class_histogram(data, title="Untitled"):
|
||||||
|
plt.figure(title)
|
||||||
|
plt.clf()
|
||||||
|
plt.title(title)
|
||||||
|
dist, counts = np.unique(data[:, -1], return_counts=True)
|
||||||
|
plt.bar(dist, counts)
|
||||||
|
plt.xticks(dist)
|
||||||
|
print("Call matplotlib.pyplot.show() to see the plot.")
|
||||||
|
|
||||||
|
|
||||||
|
def ntimer(n=10):
|
||||||
|
"""Wraps a function which wraps another function to time it."""
|
||||||
|
if n < 1:
|
||||||
|
raise (Exception(f"Invalid n = {n} given."))
|
||||||
|
|
||||||
|
def timer(func):
|
||||||
|
"""Wraps `func` with a timer and returns the wrapped `func`."""
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
rv = None
|
||||||
|
before = time()
|
||||||
|
for _ in range(n):
|
||||||
|
rv = func(*args, **kwargs)
|
||||||
|
after = time()
|
||||||
|
elapsed = after - before
|
||||||
|
print(f"Elapsed: {elapsed*1e3:02.02f} ms")
|
||||||
|
return rv
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return timer
|
||||||
|
|
||||||
|
|
||||||
|
def memoize(verbose=True):
|
||||||
|
"""Wraps a function which wraps another function that memoizes."""
|
||||||
|
def memoizer(func):
|
||||||
|
"""Memoize (cache) return values of `func`.
|
||||||
|
Wraps `func` and returns the wrapped `func` so that `func`
|
||||||
|
is executed when the results are not available in the cache.
|
||||||
|
"""
|
||||||
|
cache = {}
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
t = (pickle.dumps(args), pickle.dumps(kwargs))
|
||||||
|
if t not in cache:
|
||||||
|
if verbose:
|
||||||
|
print(f"Adding NEW rv {func.__name__}{args}{kwargs} "
|
||||||
|
"to cache.")
|
||||||
|
cache[t] = func(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
if verbose:
|
||||||
|
print(f"Using OLD rv {func.__name__}{args}{kwargs} "
|
||||||
|
"from cache.")
|
||||||
|
return cache[t]
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return memoizer
|
5
setup.py
5
setup.py
@ -8,8 +8,7 @@
|
|||||||
|
|
||||||
ProtoTorch Core Package
|
ProtoTorch Core Package
|
||||||
"""
|
"""
|
||||||
from setuptools import setup
|
from setuptools import find_packages, setup
|
||||||
from setuptools import find_packages
|
|
||||||
|
|
||||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
||||||
@ -42,7 +41,7 @@ ALL = DOCS + DATASETS + EXAMPLES + TESTS
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="prototorch",
|
name="prototorch",
|
||||||
version="0.3.0-dev0",
|
version="0.4.0",
|
||||||
description="Highly extensible, GPU-supported "
|
description="Highly extensible, GPU-supported "
|
||||||
"Learning Vector Quantization (LVQ) toolbox "
|
"Learning Vector Quantization (LVQ) toolbox "
|
||||||
"built using PyTorch and its nn API.",
|
"built using PyTorch and its nn API.",
|
||||||
|
@ -12,26 +12,26 @@ from prototorch.datasets import abstract, tecator
|
|||||||
class TestAbstract(unittest.TestCase):
|
class TestAbstract(unittest.TestCase):
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.Dataset('./artifacts')[0]
|
abstract.Dataset("./artifacts")[0]
|
||||||
|
|
||||||
def test_len(self):
|
def test_len(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
len(abstract.Dataset('./artifacts'))
|
len(abstract.Dataset("./artifacts"))
|
||||||
|
|
||||||
|
|
||||||
class TestProtoDataset(unittest.TestCase):
|
class TestProtoDataset(unittest.TestCase):
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.ProtoDataset('./artifacts')[0]
|
abstract.ProtoDataset("./artifacts")[0]
|
||||||
|
|
||||||
def test_download(self):
|
def test_download(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.ProtoDataset('./artifacts').download()
|
abstract.ProtoDataset("./artifacts").download()
|
||||||
|
|
||||||
|
|
||||||
class TestTecator(unittest.TestCase):
|
class TestTecator(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.artifacts_dir = './artifacts/Tecator'
|
self.artifacts_dir = "./artifacts/Tecator"
|
||||||
self._remove_artifacts()
|
self._remove_artifacts()
|
||||||
|
|
||||||
def _remove_artifacts(self):
|
def _remove_artifacts(self):
|
||||||
@ -39,23 +39,23 @@ class TestTecator(unittest.TestCase):
|
|||||||
shutil.rmtree(self.artifacts_dir)
|
shutil.rmtree(self.artifacts_dir)
|
||||||
|
|
||||||
def test_download_false(self):
|
def test_download_false(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
self._remove_artifacts()
|
self._remove_artifacts()
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
_ = tecator.Tecator(rootdir, download=False)
|
_ = tecator.Tecator(rootdir, download=False)
|
||||||
|
|
||||||
def test_download_caching(self):
|
def test_download_caching(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
_ = tecator.Tecator(rootdir, download=True, verbose=False)
|
_ = tecator.Tecator(rootdir, download=True, verbose=False)
|
||||||
_ = tecator.Tecator(rootdir, download=False, verbose=False)
|
_ = tecator.Tecator(rootdir, download=False, verbose=False)
|
||||||
|
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = tecator.Tecator(rootdir, download=True, verbose=True)
|
train = tecator.Tecator(rootdir, download=True, verbose=True)
|
||||||
self.assertTrue('Split: Train' in train.__repr__())
|
self.assertTrue("Split: Train" in train.__repr__())
|
||||||
|
|
||||||
def test_download_train(self):
|
def test_download_train(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = tecator.Tecator(root=rootdir,
|
train = tecator.Tecator(root=rootdir,
|
||||||
train=True,
|
train=True,
|
||||||
download=True,
|
download=True,
|
||||||
@ -67,7 +67,7 @@ class TestTecator(unittest.TestCase):
|
|||||||
self.assertEqual(x_train.shape[1], 100)
|
self.assertEqual(x_train.shape[1], 100)
|
||||||
|
|
||||||
def test_download_test(self):
|
def test_download_test(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x_test, y_test = test.data, test.targets
|
x_test, y_test = test.data, test.targets
|
||||||
self.assertEqual(x_test.shape[0], 71)
|
self.assertEqual(x_test.shape[0], 71)
|
||||||
@ -75,19 +75,19 @@ class TestTecator(unittest.TestCase):
|
|||||||
self.assertEqual(x_test.shape[1], 100)
|
self.assertEqual(x_test.shape[1], 100)
|
||||||
|
|
||||||
def test_class_to_idx(self):
|
def test_class_to_idx(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = test.class_to_idx
|
_ = test.class_to_idx
|
||||||
|
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x, y = test[0]
|
x, y = test[0]
|
||||||
self.assertEqual(x.shape[0], 100)
|
self.assertEqual(x.shape[0], 100)
|
||||||
self.assertIsInstance(y, int)
|
self.assertIsInstance(y, int)
|
||||||
|
|
||||||
def test_loadable_with_dataloader(self):
|
def test_loadable_with_dataloader(self):
|
||||||
rootdir = self.artifacts_dir.rpartition('/')[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from prototorch.functions import (activations, competitions, distances,
|
|||||||
|
|
||||||
class TestActivations(unittest.TestCase):
|
class TestActivations(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
self.flist = ["identity", "sigmoid_beta", "swish_beta"]
|
||||||
self.x = torch.randn(1024, 1)
|
self.x = torch.randn(1024, 1)
|
||||||
|
|
||||||
def test_registry(self):
|
def test_registry(self):
|
||||||
@ -39,7 +39,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
self.assertEqual(1, f(1))
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
def test_unknown_deserialization(self):
|
def test_unknown_deserialization(self):
|
||||||
for funcname in ['blubb', 'foobar']:
|
for funcname in ["blubb", "foobar"]:
|
||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
_ = activations.get_activation(funcname)
|
_ = activations.get_activation(funcname)
|
||||||
|
|
||||||
@ -52,7 +52,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_sigmoid_beta1(self):
|
def test_sigmoid_beta1(self):
|
||||||
actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
|
actual = activations.sigmoid_beta(self.x, beta=1.0)
|
||||||
desired = torch.sigmoid(self.x)
|
desired = torch.sigmoid(self.x)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -60,7 +60,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_swish_beta1(self):
|
def test_swish_beta1(self):
|
||||||
actual = activations.swish_beta(self.x, beta=torch.tensor(1))
|
actual = activations.swish_beta(self.x, beta=1.0)
|
||||||
desired = self.x * torch.sigmoid(self.x)
|
desired = self.x * torch.sigmoid(self.x)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
@ -76,7 +76,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_wtac(self):
|
def test_wtac(self):
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
actual = competitions.wtac(d, labels)
|
actual = competitions.wtac(d, labels)
|
||||||
desired = torch.tensor([2, 0])
|
desired = torch.tensor([2, 0])
|
||||||
@ -86,7 +86,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_wtac_unequal_dist(self):
|
def test_wtac_unequal_dist(self):
|
||||||
d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
|
d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
|
||||||
labels = torch.tensor([0, 1, 1])
|
labels = torch.tensor([0, 1, 1])
|
||||||
actual = competitions.wtac(d, labels)
|
actual = competitions.wtac(d, labels)
|
||||||
desired = torch.tensor([0, 1])
|
desired = torch.tensor([0, 1])
|
||||||
@ -96,7 +96,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_wtac_one_hot(self):
|
def test_wtac_one_hot(self):
|
||||||
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
|
||||||
labels = torch.tensor([[0, 1], [1, 0]])
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
actual = competitions.wtac(d, labels)
|
actual = competitions.wtac(d, labels)
|
||||||
desired = torch.tensor([[0, 1], [1, 0]])
|
desired = torch.tensor([[0, 1], [1, 0]])
|
||||||
@ -106,38 +106,38 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_stratified_min(self):
|
def test_stratified_min(self):
|
||||||
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 0, 1, 2])
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = competitions.stratified_min(d, labels)
|
||||||
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_stratified_min_one_hot(self):
|
def test_stratified_min_one_hot(self):
|
||||||
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 0, 1, 2])
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
labels = torch.eye(3)[labels]
|
labels = torch.eye(3)[labels]
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = competitions.stratified_min(d, labels)
|
||||||
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_stratified_min_simple(self):
|
def test_stratified_min_simple(self):
|
||||||
d = torch.tensor([[0., 2., 3.], [8., 0, 1]])
|
d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
|
||||||
labels = torch.tensor([0, 1, 2])
|
labels = torch.tensor([0, 1, 2])
|
||||||
actual = competitions.stratified_min(d, labels)
|
actual = competitions.stratified_min(d, labels)
|
||||||
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_knnc_k1(self):
|
def test_knnc_k1(self):
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
|
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
|
||||||
desired = torch.tensor([2, 0])
|
desired = torch.tensor([2, 0])
|
||||||
@ -194,12 +194,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@ -254,14 +254,14 @@ class TestDistances(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_lpnorm_pinf(self):
|
def test_lpnorm_pinf(self):
|
||||||
actual = distances.lpnorm_distance(self.x, self.y, p=float('inf'))
|
actual = distances.lpnorm_distance(self.x, self.y, p=float("inf"))
|
||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=float('inf'),
|
p=float("inf"),
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)
|
)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
@ -275,12 +275,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@ -293,12 +293,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@ -311,8 +311,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
class TestInitializers(unittest.TestCase):
|
class TestInitializers(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flist = [
|
self.flist = [
|
||||||
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
"zeros",
|
||||||
'stratified_random'
|
"ones",
|
||||||
|
"rand",
|
||||||
|
"randn",
|
||||||
|
"stratified_mean",
|
||||||
|
"stratified_random",
|
||||||
]
|
]
|
||||||
self.x = torch.tensor(
|
self.x = torch.tensor(
|
||||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||||
@ -340,7 +344,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
self.assertEqual(1, f(1))
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
def test_unknown_deserialization(self):
|
def test_unknown_deserialization(self):
|
||||||
for funcname in ['blubb', 'foobar']:
|
for funcname in ["blubb", "foobar"]:
|
||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
_ = initializers.get_initializer(funcname)
|
_ = initializers.get_initializer(funcname)
|
||||||
|
|
||||||
@ -383,7 +387,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_equal1(self):
|
def test_stratified_mean_equal1(self):
|
||||||
pdist = torch.tensor([1, 1])
|
pdist = torch.tensor([1, 1])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
|
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -393,7 +397,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 1])
|
pdist = torch.tensor([1, 1])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -402,8 +406,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_equal2(self):
|
def test_stratified_mean_equal2(self):
|
||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
|
desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -413,8 +417,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -423,8 +427,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_unequal(self):
|
def test_stratified_mean_unequal(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
|
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -434,8 +438,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -444,8 +448,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_unequal_one_hot(self):
|
def test_stratified_mean_unequal_one_hot(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
y = torch.eye(2)[self.y]
|
y = torch.eye(2)[self.y]
|
||||||
desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
|
desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
||||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||||
@ -460,8 +464,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
y = torch.eye(2)[self.y]
|
y = torch.eye(2)[self.y]
|
||||||
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
||||||
desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
|
desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||||
desired1,
|
desired1,
|
||||||
|
@ -29,10 +29,12 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
|
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
|
||||||
|
|
||||||
def test_prototypes1d_init_without_pdist(self):
|
def test_prototypes1d_init_without_pdist(self):
|
||||||
p1 = prototypes.Prototypes1D(input_dim=6,
|
p1 = prototypes.Prototypes1D(
|
||||||
|
input_dim=6,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=4,
|
prototypes_per_class=4,
|
||||||
prototype_initializer='ones')
|
prototype_initializer="ones",
|
||||||
|
)
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.ones(8, 6)
|
desired = torch.ones(8, 6)
|
||||||
@ -45,7 +47,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
pdist = [2, 2]
|
pdist = [2, 2]
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||||
prototype_distribution=pdist,
|
prototype_distribution=pdist,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer="zeros")
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(4, 3)
|
desired = torch.zeros(4, 3)
|
||||||
@ -60,14 +62,15 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=3,
|
input_dim=3,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=None)
|
data=None,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_torch_pdist(self):
|
def test_prototypes1d_init_torch_pdist(self):
|
||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||||
prototype_distribution=pdist,
|
prototype_distribution=pdist,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer="zeros")
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(4, 3)
|
desired = torch.zeros(4, 3)
|
||||||
@ -77,24 +80,30 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_prototypes1d_init_without_inputdim_with_data(self):
|
def test_prototypes1d_init_without_inputdim_with_data(self):
|
||||||
_ = prototypes.Prototypes1D(nclasses=2,
|
_ = prototypes.Prototypes1D(
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=[[[1.], [0.]], [1, 0]])
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(nclasses=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer='stratified_mean',
|
|
||||||
data=[[[1], [0]], [1, 0]])
|
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_without_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
|
data=[[[1.0], [0.0]], [1, 0]],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prototypes1d_init_with_int_data(self):
|
||||||
|
_ = prototypes.Prototypes1D(
|
||||||
|
nclasses=2,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer="stratified_mean",
|
||||||
|
data=[[[1], [0]], [1, 0]],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prototypes1d_init_one_hot_without_data(self):
|
||||||
|
_ = prototypes.Prototypes1D(
|
||||||
|
input_dim=1,
|
||||||
|
nclasses=2,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer="stratified_mean",
|
||||||
data=None,
|
data=None,
|
||||||
one_hot_labels=True)
|
one_hot_labels=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_false(self):
|
def test_prototypes1d_init_one_hot_labels_false(self):
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
|
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
|
||||||
@ -105,9 +114,10 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.], [1.]], [[0, 1], [1, 0]]),
|
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
|
||||||
one_hot_labels=False)
|
one_hot_labels=False,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
|
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
||||||
@ -118,9 +128,10 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.], [1.]], [0, 1]),
|
data=([[0.0], [1.0]], [0, 1]),
|
||||||
one_hot_labels=True)
|
one_hot_labels=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_true(self):
|
def test_prototypes1d_init_one_hot_labels_true(self):
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
||||||
@ -132,25 +143,27 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.], [1.]], [[0], [1]]),
|
data=([[0.0], [1.0]], [[0], [1]]),
|
||||||
one_hot_labels=True)
|
one_hot_labels=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_dtype(self):
|
def test_prototypes1d_init_with_int_dtype(self):
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1], [0]], [1, 0]],
|
data=[[[1], [0]], [1, 0]],
|
||||||
dtype=torch.int32)
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_inputndim_with_data(self):
|
def test_prototypes1d_inputndim_with_data(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
_ = prototypes.Prototypes1D(input_dim=1,
|
||||||
nclasses=1,
|
nclasses=1,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
data=[[1.], [1]])
|
data=[[1.0], [1]])
|
||||||
|
|
||||||
def test_prototypes1d_inputdim_with_data(self):
|
def test_prototypes1d_inputdim_with_data(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@ -158,8 +171,9 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=2,
|
input_dim=2,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1.], [0.]], [1, 0]])
|
data=[[[1.0], [0.0]], [1, 0]],
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_nclasses_with_data(self):
|
def test_prototypes1d_nclasses_with_data(self):
|
||||||
"""Test ValueError raise if provided `nclasses` is not the same
|
"""Test ValueError raise if provided `nclasses` is not the same
|
||||||
@ -170,13 +184,14 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=1,
|
nclasses=1,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1.], [2.]], [1, 2]])
|
data=[[[1.0], [2.0]], [1, 2]],
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_with_ppc(self):
|
def test_prototypes1d_init_with_ppc(self):
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
||||||
prototypes_per_class=2,
|
prototypes_per_class=2,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer="zeros")
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(4, 3)
|
desired = torch.zeros(4, 3)
|
||||||
@ -186,9 +201,11 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_prototypes1d_init_with_pdist(self):
|
def test_prototypes1d_init_with_pdist(self):
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
p1 = prototypes.Prototypes1D(
|
||||||
|
data=[self.x, self.y],
|
||||||
prototype_distribution=[6, 9],
|
prototype_distribution=[6, 9],
|
||||||
prototype_initializer='zeros')
|
prototype_initializer="zeros",
|
||||||
|
)
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(15, 3)
|
desired = torch.zeros(15, 3)
|
||||||
@ -201,10 +218,12 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
def my_initializer(*args, **kwargs):
|
def my_initializer(*args, **kwargs):
|
||||||
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
|
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
|
||||||
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=99,
|
p1 = prototypes.Prototypes1D(
|
||||||
|
input_dim=99,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer=my_initializer)
|
prototype_initializer=my_initializer,
|
||||||
|
)
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = 99 * torch.ones(2, 99)
|
desired = 99 * torch.ones(2, 99)
|
||||||
@ -231,7 +250,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
def test_prototypes1d_validate_extra_repr_not_empty(self):
|
def test_prototypes1d_validate_extra_repr_not_empty(self):
|
||||||
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
||||||
rep = p1.extra_repr()
|
rep = p1.extra_repr()
|
||||||
self.assertNotEqual(rep, '')
|
self.assertNotEqual(rep, "")
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
del self.x, self.y, self.gen
|
del self.x, self.y, self.gen
|
||||||
@ -243,11 +262,11 @@ class TestLosses(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_glvqloss_init(self):
|
def test_glvqloss_init(self):
|
||||||
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
_ = losses.GLVQLoss(0, "swish_beta", beta=20)
|
||||||
|
|
||||||
def test_glvqloss_forward_1ppc(self):
|
def test_glvqloss_forward_1ppc(self):
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
squashing='sigmoid_beta',
|
squashing="sigmoid_beta",
|
||||||
beta=100)
|
beta=100)
|
||||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
labels = torch.tensor([0, 1])
|
labels = torch.tensor([0, 1])
|
||||||
@ -259,7 +278,7 @@ class TestLosses(unittest.TestCase):
|
|||||||
|
|
||||||
def test_glvqloss_forward_2ppc(self):
|
def test_glvqloss_forward_2ppc(self):
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
squashing='sigmoid_beta',
|
squashing="sigmoid_beta",
|
||||||
beta=100)
|
beta=100)
|
||||||
d = torch.stack([
|
d = torch.stack([
|
||||||
torch.ones(100),
|
torch.ones(100),
|
||||||
|
Loading…
Reference in New Issue
Block a user