Automatic Formatting.

This commit is contained in:
Alexander Engelsberger 2021-04-23 17:24:53 +02:00
parent e1d56595c1
commit 7c30ffe2c7
28 changed files with 393 additions and 321 deletions

View File

@ -12,9 +12,8 @@
# #
import os import os
import sys import sys
sys.path.insert(0, os.path.abspath("../../"))
import sphinx_rtd_theme sys.path.insert(0, os.path.abspath("../../"))
# -- Project information ----------------------------------------------------- # -- Project information -----------------------------------------------------
@ -128,15 +127,12 @@ latex_elements = {
# The paper size ("letterpaper" or "a4paper"). # The paper size ("letterpaper" or "a4paper").
# #
# "papersize": "letterpaper", # "papersize": "letterpaper",
# The font size ("10pt", "11pt" or "12pt"). # The font size ("10pt", "11pt" or "12pt").
# #
# "pointsize": "10pt", # "pointsize": "10pt",
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
# #
# "preamble": "", # "preamble": "",
# Latex figure (float) alignment # Latex figure (float) alignment
# #
# "figure_align": "htbp", # "figure_align": "htbp",
@ -146,15 +142,21 @@ latex_elements = {
# (source start file, target name, title, # (source start file, target name, title,
# author, documentclass [howto, manual, or own class]). # author, documentclass [howto, manual, or own class]).
latex_documents = [ latex_documents = [
(master_doc, "prototorch.tex", "ProtoTorch Documentation", (
"Jensun Ravichandran", "manual"), master_doc,
"prototorch.tex",
"ProtoTorch Documentation",
"Jensun Ravichandran",
"manual",
),
] ]
# -- Options for manual page output --------------------------------------- # -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)] man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author],
1)]
# -- Options for Texinfo output ------------------------------------------- # -- Options for Texinfo output -------------------------------------------
@ -162,9 +164,15 @@ man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
(master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch", (
"Prototype-based machine learning in PyTorch.", master_doc,
"Miscellaneous"), "prototorch",
"ProtoTorch Documentation",
author,
"prototorch",
"Prototype-based machine learning in PyTorch.",
"Miscellaneous",
),
] ]
# Example configuration for intersphinx: refer to the Python standard library. # Example configuration for intersphinx: refer to the Python standard library.

View File

@ -3,13 +3,14 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data # Prepare and preprocess the data
scaler = StandardScaler() scaler = StandardScaler()
@ -29,7 +30,8 @@ class Model(torch.nn.Module):
prototypes_per_class=3, prototypes_per_class=3,
nclasses=3, nclasses=3,
prototype_initializer="stratified_random", prototype_initializer="stratified_random",
data=[x_train, y_train]) data=[x_train, y_train],
)
def forward(self, x): def forward(self, x):
protos = self.proto_layer.prototypes protos = self.proto_layer.prototypes
@ -61,8 +63,10 @@ for epoch in range(70):
with torch.no_grad(): with torch.no_grad():
pred = wtac(dis, plabels) pred = wtac(dis, plabels)
correct = pred.eq(y_in.view_as(pred)).sum().item() correct = pred.eq(y_in.view_as(pred)).sum().item()
acc = 100. * correct / len(x_train) acc = 100.0 * correct / len(x_train)
print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%") print(
f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
)
# Take a gradient descent step # Take a gradient descent step
optimizer.zero_grad() optimizer.zero_grad()
@ -83,13 +87,15 @@ for epoch in range(70):
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
cmap = "viridis" cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0], ax.scatter(
protos[:, 1], protos[:, 0],
c=plabels, protos[:, 1],
cmap=cmap, c=plabels,
edgecolor="k", cmap=cmap,
marker="D", edgecolor="k",
s=50) marker="D",
s=50,
)
# Paint decision regions # Paint decision regions
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))

View File

@ -20,11 +20,13 @@ class Model(torch.nn.Module):
"""GMLVQ model as a siamese network.""" """GMLVQ model as a siamese network."""
super().__init__() super().__init__()
x, y = train_data.data, train_data.targets x, y = train_data.data, train_data.targets
self.p1 = Prototypes1D(input_dim=100, self.p1 = Prototypes1D(
prototypes_per_class=2, input_dim=100,
nclasses=2, prototypes_per_class=2,
prototype_initializer="stratified_random", nclasses=2,
data=[x, y]) prototype_initializer="stratified_random",
data=[x, y],
)
self.omega = torch.nn.Linear(in_features=100, self.omega = torch.nn.Linear(in_features=100,
out_features=100, out_features=100,
bias=False) bias=False)

View File

@ -13,8 +13,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from torchvision import transforms from torchvision import transforms
from prototorch.modules.losses import GLVQLoss
from prototorch.functions.helper import calculate_prototype_accuracy from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ from prototorch.modules.models import GTLVQ
# Parameters and options # Parameters and options
@ -26,32 +27,40 @@ momentum = 0.5
log_interval = 10 log_interval = 10
cuda = "cuda:1" cuda = "cuda:1"
random_seed = 1 random_seed = 1
device = torch.device(cuda if torch.cuda.is_available() else 'cpu') device = torch.device(cuda if torch.cuda.is_available() else "cpu")
# Configures reproducability # Configures reproducability
torch.manual_seed(random_seed) torch.manual_seed(random_seed)
np.random.seed(random_seed) np.random.seed(random_seed)
# Prepare and preprocess the data # Prepare and preprocess the data
train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST( train_loader = torch.utils.data.DataLoader(
'./files/', torchvision.datasets.MNIST(
train=True, "./files/",
download=True, train=True,
transform=torchvision.transforms.Compose( download=True,
[transforms.ToTensor(), transform=torchvision.transforms.Compose([
transforms.Normalize((0.1307, ), (0.3081, ))])), transforms.ToTensor(),
batch_size=batch_size_train, transforms.Normalize((0.1307, ), (0.3081, ))
shuffle=True) ]),
),
batch_size=batch_size_train,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST( test_loader = torch.utils.data.DataLoader(
'./files/', torchvision.datasets.MNIST(
train=False, "./files/",
download=True, train=False,
transform=torchvision.transforms.Compose( download=True,
[transforms.ToTensor(), transform=torchvision.transforms.Compose([
transforms.Normalize((0.1307, ), (0.3081, ))])), transforms.ToTensor(),
batch_size=batch_size_test, transforms.Normalize((0.1307, ), (0.3081, ))
shuffle=True) ]),
),
batch_size=batch_size_test,
shuffle=True,
)
# Define the GLVQ model plus appropriate feature extractor # Define the GLVQ model plus appropriate feature extractor
@ -67,25 +76,34 @@ class CNNGTLVQ(torch.nn.Module):
): ):
super(CNNGTLVQ, self).__init__() super(CNNGTLVQ, self).__init__()
#Feature Extractor - Simple CNN # Feature Extractor - Simple CNN
self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(), self.fe = nn.Sequential(
nn.Conv2d(32, 64, 3, 1), nn.ReLU(), nn.Conv2d(1, 32, 3, 1),
nn.MaxPool2d(2), nn.Dropout(0.25), nn.ReLU(),
nn.Flatten(), nn.Linear(9216, bottleneck_dim), nn.Conv2d(32, 64, 3, 1),
nn.Dropout(0.5), nn.LeakyReLU(), nn.ReLU(),
nn.LayerNorm(bottleneck_dim)) nn.MaxPool2d(2),
nn.Dropout(0.25),
nn.Flatten(),
nn.Linear(9216, bottleneck_dim),
nn.Dropout(0.5),
nn.LeakyReLU(),
nn.LayerNorm(bottleneck_dim),
)
# Forward pass of subspace and prototype initialization data through feature extractor # Forward pass of subspace and prototype initialization data through feature extractor
subspace_data = self.fe(subspace_data) subspace_data = self.fe(subspace_data)
prototype_data[0] = self.fe(prototype_data[0]) prototype_data[0] = self.fe(prototype_data[0])
# Initialization of GTLVQ # Initialization of GTLVQ
self.gtlvq = GTLVQ(num_classes, self.gtlvq = GTLVQ(
subspace_data, num_classes,
prototype_data, subspace_data,
tangent_projection_type=tangent_projection_type, prototype_data,
feature_dim=bottleneck_dim, tangent_projection_type=tangent_projection_type,
prototypes_per_class=prototypes_per_class) feature_dim=bottleneck_dim,
prototypes_per_class=prototypes_per_class,
)
def forward(self, x): def forward(self, x):
# Feature Extraction # Feature Extraction
@ -103,20 +121,24 @@ subspace_data = torch.cat(
prototype_data = next(iter(train_loader)) prototype_data = next(iter(train_loader))
# Build the CNN GTLVQ model # Build the CNN GTLVQ model
model = CNNGTLVQ(10, model = CNNGTLVQ(
subspace_data, 10,
prototype_data, subspace_data,
tangent_projection_type="local", prototype_data,
bottleneck_dim=128).to(device) tangent_projection_type="local",
bottleneck_dim=128,
).to(device)
# Optimize using SGD optimizer from `torch.optim` # Optimize using SGD optimizer from `torch.optim`
optimizer = torch.optim.Adam([{ optimizer = torch.optim.Adam(
'params': model.fe.parameters() [{
}, { "params": model.fe.parameters()
'params': model.gtlvq.parameters() }, {
}], "params": model.gtlvq.parameters()
lr=learning_rate) }],
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10) lr=learning_rate,
)
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
# Training loop # Training loop
for epoch in range(n_epochs): for epoch in range(n_epochs):
@ -139,8 +161,8 @@ for epoch in range(n_epochs):
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
acc = calculate_prototype_accuracy(distances, y_train, plabels) acc = calculate_prototype_accuracy(distances, y_train, plabels)
print( print(
f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \ f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
Train Acc: {acc.item():02.02f}') Train Acc: {acc.item():02.02f}")
# Test # Test
with torch.no_grad(): with torch.no_grad():
@ -154,9 +176,9 @@ for epoch in range(n_epochs):
i = torch.argmin(test_distances, 1) i = torch.argmin(test_distances, 1)
correct += torch.sum(y_test == test_plabels[i]) correct += torch.sum(y_test == test_plabels[i])
total += y_test.size(0) total += y_test.size(0)
print('Accuracy of the network on the test images: %d %%' % print("Accuracy of the network on the test images: %d %%" %
(torch.true_divide(correct, total) * 100)) (torch.true_divide(correct, total) * 100))
# Save the model # Save the model
PATH = './glvq_mnist_model.pth' PATH = "./glvq_mnist_model.pth"
torch.save(model.state_dict(), PATH) torch.save(model.state_dict(), PATH)

View File

@ -22,10 +22,12 @@ class Model(torch.nn.Module):
def __init__(self): def __init__(self):
"""Local-GMLVQ model.""" """Local-GMLVQ model."""
super().__init__() super().__init__()
self.p1 = Prototypes1D(input_dim=2, self.p1 = Prototypes1D(
prototype_distribution=[1, 2, 2], input_dim=2,
prototype_initializer="stratified_random", prototype_distribution=[1, 2, 2],
data=[x_train, y_train]) prototype_initializer="stratified_random",
data=[x_train, y_train],
)
omegas = torch.zeros(5, 2, 2) omegas = torch.zeros(5, 2, 2)
self.omegas = torch.nn.Parameter(omegas) self.omegas = torch.nn.Parameter(omegas)
eye_(self.omegas) eye_(self.omegas)
@ -76,14 +78,16 @@ for epoch in range(100):
ax.set_xlabel("Data dimension 1") ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
cmap = "viridis" cmap = "viridis"
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k') ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0], ax.scatter(
protos[:, 1], protos[:, 0],
c=plabels, protos[:, 1],
cmap=cmap, c=plabels,
edgecolor='k', cmap=cmap,
marker='D', edgecolor="k",
s=50) marker="D",
s=50,
)
# Paint decision regions # Paint decision regions
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))

View File

@ -5,8 +5,6 @@
# ############################################# # #############################################
__version__ = "0.3.0-dev0" __version__ = "0.3.0-dev0"
from prototorch import datasets, functions, modules
__all_core__ = [ __all_core__ = [
"datasets", "datasets",
"functions", "functions",
@ -17,6 +15,7 @@ __all_core__ = [
# Plugin Loader # Plugin Loader
# ############################################# # #############################################
import pkgutil import pkgutil
import pkg_resources import pkg_resources
__path__ = pkgutil.extend_path(__path__, __name__) __path__ = pkgutil.extend_path(__path__, __name__)
@ -25,7 +24,8 @@ __path__ = pkgutil.extend_path(__path__, __name__)
def discover_plugins(): def discover_plugins():
return { return {
entry_point.name: entry_point.load() entry_point.name: entry_point.load()
for entry_point in pkg_resources.iter_entry_points("prototorch.plugins") for entry_point in pkg_resources.iter_entry_points(
"prototorch.plugins")
} }
@ -33,14 +33,12 @@ discovered_plugins = discover_plugins()
locals().update(discovered_plugins) locals().update(discovered_plugins)
# Generate combines __version__ and __all__ # Generate combines __version__ and __all__
version_plugins = "\n".join( version_plugins = "\n".join([
[ "- " + name + ": v" + plugin.__version__
"- " + name + ": v" + plugin.__version__ for name, plugin in discovered_plugins.items()
for name, plugin in discovered_plugins.items() ])
]
)
if version_plugins != "": if version_plugins != "":
version_plugins = "\nPlugins: \n" + version_plugins version_plugins = "\nPlugins: \n" + version_plugins
version = "core: v" + __version__ + version_plugins version = "core: v" + __version__ + version_plugins
__all__ = __all_core__ + list(discovered_plugins.keys()) __all__ = __all_core__ + list(discovered_plugins.keys())

View File

@ -3,5 +3,5 @@
from .tecator import Tecator from .tecator import Tecator
__all__ = [ __all__ = [
'Tecator', "Tecator",
] ]

View File

@ -52,7 +52,8 @@ class Tecator(ProtoDataset):
""" """
_resources = [ _resources = [
("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"), ("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0",
"ba5607c580d0f91bb27dc29d13c2f8df"),
] # (google_storage_id, md5hash) ] # (google_storage_id, md5hash)
classes = ["0 - low_fat", "1 - high_fat"] classes = ["0 - low_fat", "1 - high_fat"]
@ -74,15 +75,15 @@ class Tecator(ProtoDataset):
print("Downloading...") print("Downloading...")
for fileid, md5 in self._resources: for fileid, md5 in self._resources:
filename = "tecator.npz" filename = "tecator.npz"
download_file_from_google_drive( download_file_from_google_drive(fileid,
fileid, root=self.raw_folder, filename=filename, md5=md5 root=self.raw_folder,
) filename=filename,
md5=md5)
if self.verbose: if self.verbose:
print("Processing...") print("Processing...")
with np.load( with np.load(os.path.join(self.raw_folder, "tecator.npz"),
os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False allow_pickle=False) as f:
) as f:
x_train, y_train = f["x_train"], f["y_train"] x_train, y_train = f["x_train"], f["y_train"]
x_test, y_test = f["x_test"], f["y_test"] x_test, y_test = f["x_test"], f["y_test"]
training_set = [ training_set = [
@ -94,9 +95,11 @@ class Tecator(ProtoDataset):
torch.tensor(y_test), torch.tensor(y_test),
] ]
with open(os.path.join(self.processed_folder, self.training_file), "wb") as f: with open(os.path.join(self.processed_folder, self.training_file),
"wb") as f:
torch.save(training_set, f) torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self.test_file), "wb") as f: with open(os.path.join(self.processed_folder, self.test_file),
"wb") as f:
torch.save(test_set, f) torch.save(test_set, f)
if self.verbose: if self.verbose:

View File

@ -4,9 +4,9 @@ from .activations import identity, sigmoid_beta, swish_beta
from .competitions import knnc, wtac from .competitions import knnc, wtac
__all__ = [ __all__ = [
'identity', "identity",
'sigmoid_beta', "sigmoid_beta",
'swish_beta', "swish_beta",
'knnc', "knnc",
'wtac', "wtac",
] ]

View File

@ -61,4 +61,4 @@ def get_activation(funcname):
return funcname return funcname
if funcname in ACTIVATIONS: if funcname in ACTIVATIONS:
return ACTIVATIONS.get(funcname) return ACTIVATIONS.get(funcname)
raise NameError(f'Activation {funcname} was not found.') raise NameError(f"Activation {funcname} was not found.")

View File

@ -12,7 +12,7 @@ def stratified_min(distances, labels):
return distances return distances
batch_size = distances.size()[0] batch_size = distances.size()[0]
winning_distances = torch.zeros(nclasses, batch_size) winning_distances = torch.zeros(nclasses, batch_size)
inf = torch.full_like(distances.T, fill_value=float('inf')) inf = torch.full_like(distances.T, fill_value=float("inf"))
# distances_to_wpluses = torch.where(matcher, distances, inf) # distances_to_wpluses = torch.where(matcher, distances, inf)
for i, cl in enumerate(clabels): for i, cl in enumerate(clabels):
# cdists = distances.T[labels == cl] # cdists = distances.T[labels == cl]

View File

@ -1,12 +1,10 @@
"""ProtoTorch distance functions.""" """ProtoTorch distance functions."""
import torch
from prototorch.functions.helper import (
equal_int_shape,
_int_and_mixed_shape,
_check_shapes,
)
import numpy as np import numpy as np
import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
equal_int_shape)
def squared_euclidean_distance(x, y): def squared_euclidean_distance(x, y):

View File

@ -23,7 +23,7 @@ def predict_label(y_pred, plabels):
def mixed_shape(inputs): def mixed_shape(inputs):
if not torch.is_tensor(inputs): if not torch.is_tensor(inputs):
raise ValueError('Input must be a tensor.') raise ValueError("Input must be a tensor.")
else: else:
int_shape = list(inputs.shape) int_shape = list(inputs.shape)
# sometimes int_shape returns mixed integer types # sometimes int_shape returns mixed integer types
@ -39,11 +39,11 @@ def mixed_shape(inputs):
def equal_int_shape(shape_1, shape_2): def equal_int_shape(shape_1, shape_2):
if not isinstance(shape_1, if not isinstance(shape_1,
(tuple, list)) or not isinstance(shape_2, (tuple, list)): (tuple, list)) or not isinstance(shape_2, (tuple, list)):
raise ValueError('Input shapes must list or tuple.') raise ValueError("Input shapes must list or tuple.")
for shape in [shape_1, shape_2]: for shape in [shape_1, shape_2]:
if not all([isinstance(x, int) or x is None for x in shape]): if not all([isinstance(x, int) or x is None for x in shape]):
raise ValueError( raise ValueError(
'Input shapes must be list or tuple of int and None values.') "Input shapes must be list or tuple of int and None values.")
if len(shape_1) != len(shape_2): if len(shape_1) != len(shape_2):
return False return False

View File

@ -104,4 +104,4 @@ def get_initializer(funcname):
return funcname return funcname
if funcname in INITIALIZERS: if funcname in INITIALIZERS:
return INITIALIZERS.get(funcname) return INITIALIZERS.get(funcname)
raise NameError(f'Initializer {funcname} was not found.') raise NameError(f"Initializer {funcname} was not found.")

View File

@ -11,7 +11,7 @@ def _get_dp_dm(distances, targets, plabels):
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
not_matcher = torch.bitwise_not(matcher) not_matcher = torch.bitwise_not(matcher)
inf = torch.full_like(distances, fill_value=float('inf')) inf = torch.full_like(distances, fill_value=float("inf"))
d_matching = torch.where(matcher, distances, inf) d_matching = torch.where(matcher, distances, inf)
d_unmatching = torch.where(not_matcher, distances, inf) d_unmatching = torch.where(not_matcher, distances, inf)
dp = torch.min(d_matching, dim=1, keepdim=True).values dp = torch.min(d_matching, dim=1, keepdim=True).values

View File

@ -1,7 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function from __future__ import absolute_import, division, print_function
from __future__ import absolute_import
from __future__ import division
import torch import torch

View File

@ -3,5 +3,5 @@
from .prototypes import Prototypes1D from .prototypes import Prototypes1D
__all__ = [ __all__ = [
'Prototypes1D', "Prototypes1D",
] ]

View File

@ -7,7 +7,7 @@ from prototorch.functions.losses import glvq_loss
class GLVQLoss(torch.nn.Module): class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs): def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.margin = margin self.margin = margin
self.squashing = get_activation(squashing) self.squashing = get_activation(squashing)
@ -37,4 +37,4 @@ class NeuralGasEnergy(torch.nn.Module):
@staticmethod @staticmethod
def _nghood_fn(rankings, lm): def _nghood_fn(rankings, lm):
return torch.exp(-rankings / lm) return torch.exp(-rankings / lm)

View File

@ -1,9 +1,11 @@
from torch import nn
import torch import torch
from prototorch.modules.prototypes import Prototypes1D from torch import nn
from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
class GTLVQ(nn.Module): class GTLVQ(nn.Module):
@ -71,7 +73,7 @@ class GTLVQ(nn.Module):
subspace_data=None, subspace_data=None,
prototype_data=None, prototype_data=None,
subspace_size=256, subspace_size=256,
tangent_projection_type='local', tangent_projection_type="local",
prototypes_per_class=2, prototypes_per_class=2,
feature_dim=256, feature_dim=256,
): ):
@ -82,37 +84,39 @@ class GTLVQ(nn.Module):
self.feature_dim = feature_dim self.feature_dim = feature_dim
if subspace_data is None: if subspace_data is None:
raise ValueError('Init Data must be specified!') raise ValueError("Init Data must be specified!")
self.tpt = tangent_projection_type self.tpt = tangent_projection_type
with torch.no_grad(): with torch.no_grad():
if self.tpt == 'local' or self.tpt == 'local_proj': if self.tpt == "local" or self.tpt == "local_proj":
self.init_local_subspace(subspace_data) self.init_local_subspace(subspace_data)
elif self.tpt == 'global': elif self.tpt == "global":
self.init_gobal_subspace(subspace_data, subspace_size) self.init_gobal_subspace(subspace_data, subspace_size)
else: else:
self.subspaces = None self.subspaces = None
# Hypothesis-Margin-Classifier # Hypothesis-Margin-Classifier
self.cls = Prototypes1D(input_dim=feature_dim, self.cls = Prototypes1D(
prototypes_per_class=prototypes_per_class, input_dim=feature_dim,
nclasses=num_classes, prototypes_per_class=prototypes_per_class,
prototype_initializer='stratified_mean', nclasses=num_classes,
data=prototype_data) prototype_initializer="stratified_mean",
data=prototype_data,
)
def forward(self, x): def forward(self, x):
# Tangent Projection # Tangent Projection
if self.tpt == 'local_proj': if self.tpt == "local_proj":
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos, x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2) 1).unsqueeze(2))
dis, proj_x = self.local_tangent_projection(x_conform) dis, proj_x = self.local_tangent_projection(x_conform)
proj_x = proj_x.reshape(x.shape[0] * self.num_protos, proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
self.feature_dim) self.feature_dim)
return proj_x, dis return proj_x, dis
elif self.tpt == "local": elif self.tpt == "local":
x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos, x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
1).unsqueeze(2) 1).unsqueeze(2))
dis = tangent_distance(x_conform, self.cls.prototypes, dis = tangent_distance(x_conform, self.cls.prototypes,
self.subspaces) self.subspaces)
elif self.tpt == "gloabl": elif self.tpt == "gloabl":
@ -127,25 +131,27 @@ class GTLVQ(nn.Module):
_, _, v = torch.svd(data) _, _, v = torch.svd(data)
subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = subspace[:, :num_subspaces] subspaces = subspace[:, :num_subspaces]
self.subspaces = torch.nn.Parameter( self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True) subspaces).clone().detach().requires_grad_(True))
def init_local_subspace(self, data): def init_local_subspace(self, data):
_, _, v = torch.svd(data) _, _, v = torch.svd(data)
inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
subspaces = inital_projector.unsqueeze(0).repeat_interleave( subspaces = inital_projector.unsqueeze(0).repeat_interleave(
self.num_protos, 0) self.num_protos, 0)
self.subspaces = torch.nn.Parameter( self.subspaces = (torch.nn.Parameter(
subspaces).clone().detach().requires_grad_(True) subspaces).clone().detach().requires_grad_(True))
def global_tangent_distances(self, x): def global_tangent_distances(self, x):
# Tangent Projection # Tangent Projection
x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces x, projected_prototypes = (
x @ self.subspaces,
self.cls.prototypes @ self.subspaces,
)
# Euclidean Distance # Euclidean Distance
return euclidean_distance_matrix(x, projected_prototypes) return euclidean_distance_matrix(x, projected_prototypes)
def local_tangent_projection(self, def local_tangent_projection(self, signals):
signals):
# Note: subspaces is always assumed as transposed and must be orthogonal! # Note: subspaces is always assumed as transposed and must be orthogonal!
# shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
# shape(protos): proto_number x dim1 x dim2 x ... x dimN # shape(protos): proto_number x dim1 x dim2 x ... x dimN
@ -183,8 +189,7 @@ class GTLVQ(nn.Module):
def orthogonalize_subspace(self): def orthogonalize_subspace(self):
if self.subspaces is not None: if self.subspaces is not None:
with torch.no_grad(): with torch.no_grad():
ortho_subpsaces = orthogonalization( ortho_subpsaces = (orthogonalization(self.subspaces)
self.subspaces if self.tpt == "global" else
) if self.tpt == 'global' else torch.nn.init.orthogonal_( torch.nn.init.orthogonal_(self.subspaces))
self.subspaces)
self.subspaces.copy_(ortho_subpsaces) self.subspaces.copy_(ortho_subpsaces)

View File

@ -29,14 +29,16 @@ class Prototypes1D(_Prototypes):
TODO Complete this doc-string. TODO Complete this doc-string.
""" """
def __init__(self, def __init__(
prototypes_per_class=1, self,
prototype_initializer="ones", prototypes_per_class=1,
prototype_distribution=None, prototype_initializer="ones",
data=None, prototype_distribution=None,
dtype=torch.float32, data=None,
one_hot_labels=False, dtype=torch.float32,
**kwargs): one_hot_labels=False,
**kwargs,
):
# Convert tensors to python lists before processing # Convert tensors to python lists before processing
if prototype_distribution is not None: if prototype_distribution is not None:

View File

@ -1 +0,0 @@
from .colors import color_scheme, get_legend_handles

View File

@ -1,13 +1,13 @@
"""Easy matplotlib animation. From https://github.com/jwkvam/celluloid.""" """Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
from typing import Dict, List
from collections import defaultdict from collections import defaultdict
from typing import Dict, List
from matplotlib.figure import Figure
from matplotlib.artist import Artist
from matplotlib.animation import ArtistAnimation from matplotlib.animation import ArtistAnimation
from matplotlib.artist import Artist
from matplotlib.figure import Figure
__version__ = '0.2.0' __version__ = "0.2.0"
class Camera: class Camera:
@ -19,7 +19,7 @@ class Camera:
self._offsets: Dict[str, Dict[int, int]] = { self._offsets: Dict[str, Dict[int, int]] = {
k: defaultdict(int) k: defaultdict(int)
for k in for k in
['collections', 'patches', 'lines', 'texts', 'artists', 'images'] ["collections", "patches", "lines", "texts", "artists", "images"]
} }
self._photos: List[List[Artist]] = [] self._photos: List[List[Artist]] = []

View File

@ -1,13 +1,14 @@
"""ProtoFlow color utilities.""" """ProtoFlow color utilities."""
from matplotlib import cm
from matplotlib.colors import Normalize
from matplotlib.colors import to_hex
from matplotlib.colors import to_rgb
import matplotlib.lines as mlines import matplotlib.lines as mlines
from matplotlib import cm
from matplotlib.colors import Normalize, to_hex, to_rgb
def color_scheme(n, cmap="viridis", form="hex", tikz=False, def color_scheme(n,
cmap="viridis",
form="hex",
tikz=False,
zero_indexed=False): zero_indexed=False):
"""Return *n* colors from the color scheme. """Return *n* colors from the color scheme.
@ -57,13 +58,16 @@ def get_legend_handles(labels, marker="dots", zero_indexed=False):
zero_indexed=zero_indexed) zero_indexed=zero_indexed)
for label, color in zip(labels, colors.values()): for label, color in zip(labels, colors.values()):
if marker == "dots": if marker == "dots":
handle = mlines.Line2D([], [], handle = mlines.Line2D(
color="white", [],
markerfacecolor=color, [],
marker="o", color="white",
markersize=10, markerfacecolor=color,
markeredgecolor="k", marker="o",
label=label) markersize=10,
markeredgecolor="k",
label=label,
)
else: else:
handle = mlines.Line2D([], [], handle = mlines.Line2D([], [],
color=color, color=color,

View File

@ -11,17 +11,17 @@ import numpy as np
def progressbar(title, value, end, bar_width=20): def progressbar(title, value, end, bar_width=20):
percent = float(value) / end percent = float(value) / end
arrow = '=' * int(round(percent * bar_width) - 1) + '>' arrow = "=" * int(round(percent * bar_width) - 1) + ">"
spaces = '.' * (bar_width - len(arrow)) spaces = "." * (bar_width - len(arrow))
sys.stdout.write('\r{}: [{}] {}%'.format(title, arrow + spaces, sys.stdout.write("\r{}: [{}] {}%".format(title, arrow + spaces,
int(round(percent * 100)))) int(round(percent * 100))))
sys.stdout.flush() sys.stdout.flush()
if percent == 1.0: if percent == 1.0:
print() print()
def prettify_string(inputs, start='', sep=' ', end='\n'): def prettify_string(inputs, start="", sep=" ", end="\n"):
outputs = start + ' '.join(inputs.split()) + end outputs = start + " ".join(inputs.split()) + end
return outputs return outputs
@ -29,22 +29,22 @@ def pretty_print(inputs):
print(prettify_string(inputs)) print(prettify_string(inputs))
def writelog(self, *logs, logdir='./logs', logfile='run.txt'): def writelog(self, *logs, logdir="./logs", logfile="run.txt"):
f = os.path.join(logdir, logfile) f = os.path.join(logdir, logfile)
with open(f, 'a+') as fh: with open(f, "a+") as fh:
for log in logs: for log in logs:
fh.write(log) fh.write(log)
fh.write('\n') fh.write("\n")
def start_tensorboard(self, logdir='./logs'): def start_tensorboard(self, logdir="./logs"):
cmd = f'tensorboard --logdir={logdir} --port=6006' cmd = f"tensorboard --logdir={logdir} --port=6006"
os.system(cmd) os.system(cmd)
def make_directory(save_dir): def make_directory(save_dir):
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
print(f'Making directory {save_dir}.') print(f"Making directory {save_dir}.")
os.mkdir(save_dir) os.mkdir(save_dir)
@ -52,36 +52,36 @@ def make_gif(filenames, duration, output_file=None):
try: try:
import imageio import imageio
except ModuleNotFoundError as e: except ModuleNotFoundError as e:
print('Please install Protoflow with [other] extra requirements.') print("Please install Protoflow with [other] extra requirements.")
raise (e) raise (e)
images = list() images = list()
for filename in filenames: for filename in filenames:
images.append(imageio.imread(filename)) images.append(imageio.imread(filename))
if not output_file: if not output_file:
output_file = f'makegif.gif' output_file = f"makegif.gif"
if images: if images:
imageio.mimwrite(output_file, images, duration=duration) imageio.mimwrite(output_file, images, duration=duration)
def gif_from_dir(directory, def gif_from_dir(directory,
duration, duration,
prefix='', prefix="",
output_file=None, output_file=None,
verbose=True): verbose=True):
images = os.listdir(directory) images = os.listdir(directory)
if verbose: if verbose:
print(f'Making gif from {len(images)} images under {directory}.') print(f"Making gif from {len(images)} images under {directory}.")
filenames = list() filenames = list()
# Sort images # Sort images
images = sorted( images = sorted(
images, images,
key=lambda img: int(os.path.splitext(img)[0].replace(prefix, ''))) key=lambda img: int(os.path.splitext(img)[0].replace(prefix, "")))
for image in images: for image in images:
fname = os.path.join(directory, image) fname = os.path.join(directory, image)
filenames.append(fname) filenames.append(fname)
if not output_file: if not output_file:
output_file = os.path.join(directory, 'makegif.gif') output_file = os.path.join(directory, "makegif.gif")
make_gif(filenames=filenames, duration=duration, output_file=output_file) make_gif(filenames=filenames, duration=duration, output_file=output_file)
@ -95,12 +95,12 @@ def predict_and_score(clf,
x_test, x_test,
y_test, y_test,
verbose=False, verbose=False,
title='Test accuracy'): title="Test accuracy"):
y_pred = clf.predict(x_test) y_pred = clf.predict(x_test)
accuracy = np.sum(y_test == y_pred) accuracy = np.sum(y_test == y_pred)
normalized_acc = accuracy / float(len(y_test)) normalized_acc = accuracy / float(len(y_test))
if verbose: if verbose:
print(f'{title}: {normalized_acc * 100:06.04f}%') print(f"{title}: {normalized_acc * 100:06.04f}%")
return normalized_acc return normalized_acc
@ -124,6 +124,7 @@ def replace_in(arr, replacement_dict, inplace=False):
new_arr = arr new_arr = arr
else: else:
import copy import copy
new_arr = copy.deepcopy(arr) new_arr = copy.deepcopy(arr)
for k, v in replacement_dict.items(): for k, v in replacement_dict.items():
new_arr[arr == k] = v new_arr[arr == k] = v
@ -135,7 +136,7 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
preserve the class distribution in subsamples of the dataset. preserve the class distribution in subsamples of the dataset.
""" """
if train + val > 1.0: if train + val > 1.0:
raise ValueError('Invalid split values for train and val.') raise ValueError("Invalid split values for train and val.")
Y = data[:, -1] Y = data[:, -1]
labels = set(Y) labels = set(Y)
hist = dict() hist = dict()
@ -183,20 +184,20 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
return train_data, val_data, test_data return train_data, val_data, test_data
def class_histogram(data, title='Untitled'): def class_histogram(data, title="Untitled"):
plt.figure(title) plt.figure(title)
plt.clf() plt.clf()
plt.title(title) plt.title(title)
dist, counts = np.unique(data[:, -1], return_counts=True) dist, counts = np.unique(data[:, -1], return_counts=True)
plt.bar(dist, counts) plt.bar(dist, counts)
plt.xticks(dist) plt.xticks(dist)
print('Call matplotlib.pyplot.show() to see the plot.') print("Call matplotlib.pyplot.show() to see the plot.")
def ntimer(n=10): def ntimer(n=10):
"""Wraps a function which wraps another function to time it.""" """Wraps a function which wraps another function to time it."""
if n < 1: if n < 1:
raise (Exception(f'Invalid n = {n} given.')) raise (Exception(f"Invalid n = {n} given."))
def timer(func): def timer(func):
"""Wraps `func` with a timer and returns the wrapped `func`.""" """Wraps `func` with a timer and returns the wrapped `func`."""
@ -207,7 +208,7 @@ def ntimer(n=10):
rv = func(*args, **kwargs) rv = func(*args, **kwargs)
after = time() after = time()
elapsed = after - before elapsed = after - before
print(f'Elapsed: {elapsed*1e3:02.02f} ms') print(f"Elapsed: {elapsed*1e3:02.02f} ms")
return rv return rv
return wrapper return wrapper
@ -228,15 +229,15 @@ def memoize(verbose=True):
t = (pickle.dumps(args), pickle.dumps(kwargs)) t = (pickle.dumps(args), pickle.dumps(kwargs))
if t not in cache: if t not in cache:
if verbose: if verbose:
print(f'Adding NEW rv {func.__name__}{args}{kwargs} ' print(f"Adding NEW rv {func.__name__}{args}{kwargs} "
'to cache.') "to cache.")
cache[t] = func(*args, **kwargs) cache[t] = func(*args, **kwargs)
else: else:
if verbose: if verbose:
print(f'Using OLD rv {func.__name__}{args}{kwargs} ' print(f"Using OLD rv {func.__name__}{args}{kwargs} "
'from cache.') "from cache.")
return cache[t] return cache[t]
return wrapper return wrapper
return memoizer return memoizer

View File

@ -8,8 +8,7 @@
ProtoTorch Core Package ProtoTorch Core Package
""" """
from setuptools import setup from setuptools import find_packages, setup
from setuptools import find_packages
PROJECT_URL = "https://github.com/si-cim/prototorch" PROJECT_URL = "https://github.com/si-cim/prototorch"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git" DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"

View File

@ -12,26 +12,26 @@ from prototorch.datasets import abstract, tecator
class TestAbstract(unittest.TestCase): class TestAbstract(unittest.TestCase):
def test_getitem(self): def test_getitem(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
abstract.Dataset('./artifacts')[0] abstract.Dataset("./artifacts")[0]
def test_len(self): def test_len(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
len(abstract.Dataset('./artifacts')) len(abstract.Dataset("./artifacts"))
class TestProtoDataset(unittest.TestCase): class TestProtoDataset(unittest.TestCase):
def test_getitem(self): def test_getitem(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
abstract.ProtoDataset('./artifacts')[0] abstract.ProtoDataset("./artifacts")[0]
def test_download(self): def test_download(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
abstract.ProtoDataset('./artifacts').download() abstract.ProtoDataset("./artifacts").download()
class TestTecator(unittest.TestCase): class TestTecator(unittest.TestCase):
def setUp(self): def setUp(self):
self.artifacts_dir = './artifacts/Tecator' self.artifacts_dir = "./artifacts/Tecator"
self._remove_artifacts() self._remove_artifacts()
def _remove_artifacts(self): def _remove_artifacts(self):
@ -39,23 +39,23 @@ class TestTecator(unittest.TestCase):
shutil.rmtree(self.artifacts_dir) shutil.rmtree(self.artifacts_dir)
def test_download_false(self): def test_download_false(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
self._remove_artifacts() self._remove_artifacts()
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
_ = tecator.Tecator(rootdir, download=False) _ = tecator.Tecator(rootdir, download=False)
def test_download_caching(self): def test_download_caching(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
_ = tecator.Tecator(rootdir, download=True, verbose=False) _ = tecator.Tecator(rootdir, download=True, verbose=False)
_ = tecator.Tecator(rootdir, download=False, verbose=False) _ = tecator.Tecator(rootdir, download=False, verbose=False)
def test_repr(self): def test_repr(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
train = tecator.Tecator(rootdir, download=True, verbose=True) train = tecator.Tecator(rootdir, download=True, verbose=True)
self.assertTrue('Split: Train' in train.__repr__()) self.assertTrue("Split: Train" in train.__repr__())
def test_download_train(self): def test_download_train(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
train = tecator.Tecator(root=rootdir, train = tecator.Tecator(root=rootdir,
train=True, train=True,
download=True, download=True,
@ -67,7 +67,7 @@ class TestTecator(unittest.TestCase):
self.assertEqual(x_train.shape[1], 100) self.assertEqual(x_train.shape[1], 100)
def test_download_test(self): def test_download_test(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = tecator.Tecator(root=rootdir, train=False, verbose=False)
x_test, y_test = test.data, test.targets x_test, y_test = test.data, test.targets
self.assertEqual(x_test.shape[0], 71) self.assertEqual(x_test.shape[0], 71)
@ -75,19 +75,19 @@ class TestTecator(unittest.TestCase):
self.assertEqual(x_test.shape[1], 100) self.assertEqual(x_test.shape[1], 100)
def test_class_to_idx(self): def test_class_to_idx(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = tecator.Tecator(root=rootdir, train=False, verbose=False)
_ = test.class_to_idx _ = test.class_to_idx
def test_getitem(self): def test_getitem(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = tecator.Tecator(root=rootdir, train=False, verbose=False)
x, y = test[0] x, y = test[0]
self.assertEqual(x.shape[0], 100) self.assertEqual(x.shape[0], 100)
self.assertIsInstance(y, int) self.assertIsInstance(y, int)
def test_loadable_with_dataloader(self): def test_loadable_with_dataloader(self):
rootdir = self.artifacts_dir.rpartition('/')[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = tecator.Tecator(root=rootdir, train=False, verbose=False)
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True) _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)

View File

@ -11,7 +11,7 @@ from prototorch.functions import (activations, competitions, distances,
class TestActivations(unittest.TestCase): class TestActivations(unittest.TestCase):
def setUp(self): def setUp(self):
self.flist = ['identity', 'sigmoid_beta', 'swish_beta'] self.flist = ["identity", "sigmoid_beta", "swish_beta"]
self.x = torch.randn(1024, 1) self.x = torch.randn(1024, 1)
def test_registry(self): def test_registry(self):
@ -39,7 +39,7 @@ class TestActivations(unittest.TestCase):
self.assertEqual(1, f(1)) self.assertEqual(1, f(1))
def test_unknown_deserialization(self): def test_unknown_deserialization(self):
for funcname in ['blubb', 'foobar']: for funcname in ["blubb", "foobar"]:
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = activations.get_activation(funcname) _ = activations.get_activation(funcname)
@ -76,7 +76,7 @@ class TestCompetitions(unittest.TestCase):
pass pass
def test_wtac(self): def test_wtac(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]]) d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
labels = torch.tensor([0, 1, 2, 3]) labels = torch.tensor([0, 1, 2, 3])
actual = competitions.wtac(d, labels) actual = competitions.wtac(d, labels)
desired = torch.tensor([2, 0]) desired = torch.tensor([2, 0])
@ -86,7 +86,7 @@ class TestCompetitions(unittest.TestCase):
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_wtac_unequal_dist(self): def test_wtac_unequal_dist(self):
d = torch.tensor([[2., 3., 4.], [2., 3., 1.]]) d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
labels = torch.tensor([0, 1, 1]) labels = torch.tensor([0, 1, 1])
actual = competitions.wtac(d, labels) actual = competitions.wtac(d, labels)
desired = torch.tensor([0, 1]) desired = torch.tensor([0, 1])
@ -96,7 +96,7 @@ class TestCompetitions(unittest.TestCase):
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_wtac_one_hot(self): def test_wtac_one_hot(self):
d = torch.tensor([[1.99, 3.01], [3., 2.01]]) d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
labels = torch.tensor([[0, 1], [1, 0]]) labels = torch.tensor([[0, 1], [1, 0]])
actual = competitions.wtac(d, labels) actual = competitions.wtac(d, labels)
desired = torch.tensor([[0, 1], [1, 0]]) desired = torch.tensor([[0, 1], [1, 0]])
@ -106,38 +106,38 @@ class TestCompetitions(unittest.TestCase):
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_stratified_min(self): def test_stratified_min(self):
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
labels = torch.tensor([0, 0, 1, 2]) labels = torch.tensor([0, 0, 1, 2])
actual = competitions.stratified_min(d, labels) actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]]) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_stratified_min_one_hot(self): def test_stratified_min_one_hot(self):
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]]) d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
labels = torch.tensor([0, 0, 1, 2]) labels = torch.tensor([0, 0, 1, 2])
labels = torch.eye(3)[labels] labels = torch.eye(3)[labels]
actual = competitions.stratified_min(d, labels) actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]]) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_stratified_min_simple(self): def test_stratified_min_simple(self):
d = torch.tensor([[0., 2., 3.], [8., 0, 1]]) d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
labels = torch.tensor([0, 1, 2]) labels = torch.tensor([0, 1, 2])
actual = competitions.stratified_min(d, labels) actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]]) desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_knnc_k1(self): def test_knnc_k1(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]]) d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
labels = torch.tensor([0, 1, 2, 3]) labels = torch.tensor([0, 1, 2, 3])
actual = competitions.knnc(d, labels, k=torch.tensor([1])) actual = competitions.knnc(d, labels, k=torch.tensor([1]))
desired = torch.tensor([2, 0]) desired = torch.tensor([2, 0])
@ -194,12 +194,12 @@ class TestDistances(unittest.TestCase):
desired = torch.empty(self.nx, self.ny) desired = torch.empty(self.nx, self.ny)
for i in range(self.nx): for i in range(self.nx):
for j in range(self.ny): for j in range(self.ny):
desired[i][j] = torch.nn.functional.pairwise_distance( desired[i][j] = (torch.nn.functional.pairwise_distance(
self.x[i].reshape(1, -1), self.x[i].reshape(1, -1),
self.y[j].reshape(1, -1), self.y[j].reshape(1, -1),
p=2, p=2,
keepdim=False, keepdim=False,
)**2 )**2)
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=2) decimal=2)
@ -254,14 +254,14 @@ class TestDistances(unittest.TestCase):
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_lpnorm_pinf(self): def test_lpnorm_pinf(self):
actual = distances.lpnorm_distance(self.x, self.y, p=float('inf')) actual = distances.lpnorm_distance(self.x, self.y, p=float("inf"))
desired = torch.empty(self.nx, self.ny) desired = torch.empty(self.nx, self.ny)
for i in range(self.nx): for i in range(self.nx):
for j in range(self.ny): for j in range(self.ny):
desired[i][j] = torch.nn.functional.pairwise_distance( desired[i][j] = torch.nn.functional.pairwise_distance(
self.x[i].reshape(1, -1), self.x[i].reshape(1, -1),
self.y[j].reshape(1, -1), self.y[j].reshape(1, -1),
p=float('inf'), p=float("inf"),
keepdim=False, keepdim=False,
) )
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
@ -275,12 +275,12 @@ class TestDistances(unittest.TestCase):
desired = torch.empty(self.nx, self.ny) desired = torch.empty(self.nx, self.ny)
for i in range(self.nx): for i in range(self.nx):
for j in range(self.ny): for j in range(self.ny):
desired[i][j] = torch.nn.functional.pairwise_distance( desired[i][j] = (torch.nn.functional.pairwise_distance(
self.x[i].reshape(1, -1), self.x[i].reshape(1, -1),
self.y[j].reshape(1, -1), self.y[j].reshape(1, -1),
p=2, p=2,
keepdim=False, keepdim=False,
)**2 )**2)
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=2) decimal=2)
@ -293,12 +293,12 @@ class TestDistances(unittest.TestCase):
desired = torch.empty(self.nx, self.ny) desired = torch.empty(self.nx, self.ny)
for i in range(self.nx): for i in range(self.nx):
for j in range(self.ny): for j in range(self.ny):
desired[i][j] = torch.nn.functional.pairwise_distance( desired[i][j] = (torch.nn.functional.pairwise_distance(
self.x[i].reshape(1, -1), self.x[i].reshape(1, -1),
self.y[j].reshape(1, -1), self.y[j].reshape(1, -1),
p=2, p=2,
keepdim=False, keepdim=False,
)**2 )**2)
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=2) decimal=2)
@ -311,8 +311,12 @@ class TestDistances(unittest.TestCase):
class TestInitializers(unittest.TestCase): class TestInitializers(unittest.TestCase):
def setUp(self): def setUp(self):
self.flist = [ self.flist = [
'zeros', 'ones', 'rand', 'randn', 'stratified_mean', "zeros",
'stratified_random' "ones",
"rand",
"randn",
"stratified_mean",
"stratified_random",
] ]
self.x = torch.tensor( self.x = torch.tensor(
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]], [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
@ -340,7 +344,7 @@ class TestInitializers(unittest.TestCase):
self.assertEqual(1, f(1)) self.assertEqual(1, f(1))
def test_unknown_deserialization(self): def test_unknown_deserialization(self):
for funcname in ['blubb', 'foobar']: for funcname in ["blubb", "foobar"]:
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = initializers.get_initializer(funcname) _ = initializers.get_initializer(funcname)
@ -383,7 +387,7 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_equal1(self): def test_stratified_mean_equal1(self):
pdist = torch.tensor([1, 1]) pdist = torch.tensor([1, 1])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]]) desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
@ -393,7 +397,7 @@ class TestInitializers(unittest.TestCase):
pdist = torch.tensor([1, 1]) pdist = torch.tensor([1, 1])
actual, _ = initializers.stratified_random(self.x, self.y, pdist, actual, _ = initializers.stratified_random(self.x, self.y, pdist,
False) False)
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]]) desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
@ -402,8 +406,8 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_equal2(self): def test_stratified_mean_equal2(self):
pdist = torch.tensor([2, 2]) pdist = torch.tensor([2, 2])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.], desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0],
[1., 1., 1.]]) [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
@ -413,8 +417,8 @@ class TestInitializers(unittest.TestCase):
pdist = torch.tensor([2, 2]) pdist = torch.tensor([2, 2])
actual, _ = initializers.stratified_random(self.x, self.y, pdist, actual, _ = initializers.stratified_random(self.x, self.y, pdist,
False) False)
desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.], desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0],
[0., 0., 0.]]) [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
@ -423,8 +427,8 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_unequal(self): def test_stratified_mean_unequal(self):
pdist = torch.tensor([1, 3]) pdist = torch.tensor([1, 3])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False) actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.], desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
[1., 1., 1.]]) [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
@ -434,8 +438,8 @@ class TestInitializers(unittest.TestCase):
pdist = torch.tensor([1, 3]) pdist = torch.tensor([1, 3])
actual, _ = initializers.stratified_random(self.x, self.y, pdist, actual, _ = initializers.stratified_random(self.x, self.y, pdist,
False) False)
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.], desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
[0., 0., 0.]]) [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
@ -444,8 +448,8 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_unequal_one_hot(self): def test_stratified_mean_unequal_one_hot(self):
pdist = torch.tensor([1, 3]) pdist = torch.tensor([1, 3])
y = torch.eye(2)[self.y] y = torch.eye(2)[self.y]
desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.], desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
[1., 1., 1.]]) [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist) actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]]) desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
mismatch = np.testing.assert_array_almost_equal(actual1, mismatch = np.testing.assert_array_almost_equal(actual1,
@ -460,8 +464,8 @@ class TestInitializers(unittest.TestCase):
pdist = torch.tensor([1, 3]) pdist = torch.tensor([1, 3])
y = torch.eye(2)[self.y] y = torch.eye(2)[self.y]
actual1, actual2 = initializers.stratified_random(self.x, y, pdist) actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.], desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
[0., 0., 0.]]) [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]]) desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
mismatch = np.testing.assert_array_almost_equal(actual1, mismatch = np.testing.assert_array_almost_equal(actual1,
desired1, desired1,

View File

@ -29,10 +29,12 @@ class TestPrototypes(unittest.TestCase):
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1) _ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
def test_prototypes1d_init_without_pdist(self): def test_prototypes1d_init_without_pdist(self):
p1 = prototypes.Prototypes1D(input_dim=6, p1 = prototypes.Prototypes1D(
nclasses=2, input_dim=6,
prototypes_per_class=4, nclasses=2,
prototype_initializer='ones') prototypes_per_class=4,
prototype_initializer="ones",
)
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.ones(8, 6) desired = torch.ones(8, 6)
@ -45,7 +47,7 @@ class TestPrototypes(unittest.TestCase):
pdist = [2, 2] pdist = [2, 2]
p1 = prototypes.Prototypes1D(input_dim=3, p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist, prototype_distribution=pdist,
prototype_initializer='zeros') prototype_initializer="zeros")
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.zeros(4, 3) desired = torch.zeros(4, 3)
@ -60,14 +62,15 @@ class TestPrototypes(unittest.TestCase):
input_dim=3, input_dim=3,
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='stratified_mean', prototype_initializer="stratified_mean",
data=None) data=None,
)
def test_prototypes1d_init_torch_pdist(self): def test_prototypes1d_init_torch_pdist(self):
pdist = torch.tensor([2, 2]) pdist = torch.tensor([2, 2])
p1 = prototypes.Prototypes1D(input_dim=3, p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist, prototype_distribution=pdist,
prototype_initializer='zeros') prototype_initializer="zeros")
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.zeros(4, 3) desired = torch.zeros(4, 3)
@ -77,24 +80,30 @@ class TestPrototypes(unittest.TestCase):
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_prototypes1d_init_without_inputdim_with_data(self): def test_prototypes1d_init_without_inputdim_with_data(self):
_ = prototypes.Prototypes1D(nclasses=2, _ = prototypes.Prototypes1D(
prototypes_per_class=1, nclasses=2,
prototype_initializer='stratified_mean', prototypes_per_class=1,
data=[[[1.], [0.]], [1, 0]]) prototype_initializer="stratified_mean",
data=[[[1.0], [0.0]], [1, 0]],
)
def test_prototypes1d_init_with_int_data(self): def test_prototypes1d_init_with_int_data(self):
_ = prototypes.Prototypes1D(nclasses=2, _ = prototypes.Prototypes1D(
prototypes_per_class=1, nclasses=2,
prototype_initializer='stratified_mean', prototypes_per_class=1,
data=[[[1], [0]], [1, 0]]) prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]],
)
def test_prototypes1d_init_one_hot_without_data(self): def test_prototypes1d_init_one_hot_without_data(self):
_ = prototypes.Prototypes1D(input_dim=1, _ = prototypes.Prototypes1D(
nclasses=2, input_dim=1,
prototypes_per_class=1, nclasses=2,
prototype_initializer='stratified_mean', prototypes_per_class=1,
data=None, prototype_initializer="stratified_mean",
one_hot_labels=True) data=None,
one_hot_labels=True,
)
def test_prototypes1d_init_one_hot_labels_false(self): def test_prototypes1d_init_one_hot_labels_false(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `False` """Test if ValueError is raised when `one_hot_labels` is set to `False`
@ -105,9 +114,10 @@ class TestPrototypes(unittest.TestCase):
input_dim=1, input_dim=1,
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='stratified_mean', prototype_initializer="stratified_mean",
data=([[0.], [1.]], [[0, 1], [1, 0]]), data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
one_hot_labels=False) one_hot_labels=False,
)
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self): def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True` """Test if ValueError is raised when `one_hot_labels` is set to `True`
@ -118,9 +128,10 @@ class TestPrototypes(unittest.TestCase):
input_dim=1, input_dim=1,
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='stratified_mean', prototype_initializer="stratified_mean",
data=([[0.], [1.]], [0, 1]), data=([[0.0], [1.0]], [0, 1]),
one_hot_labels=True) one_hot_labels=True,
)
def test_prototypes1d_init_one_hot_labels_true(self): def test_prototypes1d_init_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True` """Test if ValueError is raised when `one_hot_labels` is set to `True`
@ -132,25 +143,27 @@ class TestPrototypes(unittest.TestCase):
input_dim=1, input_dim=1,
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='stratified_mean', prototype_initializer="stratified_mean",
data=([[0.], [1.]], [[0], [1]]), data=([[0.0], [1.0]], [[0], [1]]),
one_hot_labels=True) one_hot_labels=True,
)
def test_prototypes1d_init_with_int_dtype(self): def test_prototypes1d_init_with_int_dtype(self):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='stratified_mean', prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]], data=[[[1], [0]], [1, 0]],
dtype=torch.int32) dtype=torch.int32,
)
def test_prototypes1d_inputndim_with_data(self): def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(input_dim=1, _ = prototypes.Prototypes1D(input_dim=1,
nclasses=1, nclasses=1,
prototypes_per_class=1, prototypes_per_class=1,
data=[[1.], [1]]) data=[[1.0], [1]])
def test_prototypes1d_inputdim_with_data(self): def test_prototypes1d_inputdim_with_data(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -158,8 +171,9 @@ class TestPrototypes(unittest.TestCase):
input_dim=2, input_dim=2,
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='stratified_mean', prototype_initializer="stratified_mean",
data=[[[1.], [0.]], [1, 0]]) data=[[[1.0], [0.0]], [1, 0]],
)
def test_prototypes1d_nclasses_with_data(self): def test_prototypes1d_nclasses_with_data(self):
"""Test ValueError raise if provided `nclasses` is not the same """Test ValueError raise if provided `nclasses` is not the same
@ -170,13 +184,14 @@ class TestPrototypes(unittest.TestCase):
input_dim=1, input_dim=1,
nclasses=1, nclasses=1,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer='stratified_mean', prototype_initializer="stratified_mean",
data=[[[1.], [2.]], [1, 2]]) data=[[[1.0], [2.0]], [1, 2]],
)
def test_prototypes1d_init_with_ppc(self): def test_prototypes1d_init_with_ppc(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y], p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototypes_per_class=2, prototypes_per_class=2,
prototype_initializer='zeros') prototype_initializer="zeros")
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.zeros(4, 3) desired = torch.zeros(4, 3)
@ -186,9 +201,11 @@ class TestPrototypes(unittest.TestCase):
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_prototypes1d_init_with_pdist(self): def test_prototypes1d_init_with_pdist(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y], p1 = prototypes.Prototypes1D(
prototype_distribution=[6, 9], data=[self.x, self.y],
prototype_initializer='zeros') prototype_distribution=[6, 9],
prototype_initializer="zeros",
)
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.zeros(15, 3) desired = torch.zeros(15, 3)
@ -201,10 +218,12 @@ class TestPrototypes(unittest.TestCase):
def my_initializer(*args, **kwargs): def my_initializer(*args, **kwargs):
return torch.full((2, 99), 99.0), torch.tensor([0, 1]) return torch.full((2, 99), 99.0), torch.tensor([0, 1])
p1 = prototypes.Prototypes1D(input_dim=99, p1 = prototypes.Prototypes1D(
nclasses=2, input_dim=99,
prototypes_per_class=1, nclasses=2,
prototype_initializer=my_initializer) prototypes_per_class=1,
prototype_initializer=my_initializer,
)
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = 99 * torch.ones(2, 99) desired = 99 * torch.ones(2, 99)
@ -231,7 +250,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_validate_extra_repr_not_empty(self): def test_prototypes1d_validate_extra_repr_not_empty(self):
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0]) p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
rep = p1.extra_repr() rep = p1.extra_repr()
self.assertNotEqual(rep, '') self.assertNotEqual(rep, "")
def tearDown(self): def tearDown(self):
del self.x, self.y, self.gen del self.x, self.y, self.gen
@ -243,11 +262,11 @@ class TestLosses(unittest.TestCase):
pass pass
def test_glvqloss_init(self): def test_glvqloss_init(self):
_ = losses.GLVQLoss(0, 'swish_beta', beta=20) _ = losses.GLVQLoss(0, "swish_beta", beta=20)
def test_glvqloss_forward_1ppc(self): def test_glvqloss_forward_1ppc(self):
criterion = losses.GLVQLoss(margin=0, criterion = losses.GLVQLoss(margin=0,
squashing='sigmoid_beta', squashing="sigmoid_beta",
beta=100) beta=100)
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([0, 1]) labels = torch.tensor([0, 1])
@ -259,7 +278,7 @@ class TestLosses(unittest.TestCase):
def test_glvqloss_forward_2ppc(self): def test_glvqloss_forward_2ppc(self):
criterion = losses.GLVQLoss(margin=0, criterion = losses.GLVQLoss(margin=0,
squashing='sigmoid_beta', squashing="sigmoid_beta",
beta=100) beta=100)
d = torch.stack([ d = torch.stack([
torch.ones(100), torch.ones(100),