Automatic Formatting.

commit 7c30ffe2c7 (parent e1d56595c1)
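The commit message does not say which tools produced these changes, but the patterns in the diff below are characteristic of an automated pass rather than hand edits: imports are regrouped and alphabetized (isort-style), multi-argument calls are exploded one argument per line with a trailing comma (yapf-style), and single-quoted strings become double-quoted (which yapf itself does not do, so presumably a separate quote-normalizing step). A minimal sketch of such a pass, assuming yapf and isort; the tool choice and the file glob are assumptions, not taken from this commit:

    # Hypothetical formatting pass; the commit does not name its tools.
    import pathlib

    import isort
    from yapf.yapflib.yapf_api import FormatCode

    for path in pathlib.Path(".").rglob("*.py"):
        source = path.read_text()
        source = isort.code(source)  # sort and regroup imports
        source, changed = FormatCode(source)  # reflow call sites and line breaks
        if changed:
            path.write_text(source)
    # Quote normalization ('...' -> "...") would need an additional step,
    # e.g. a dedicated quote-rewriting hook.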
@@ -12,9 +12,8 @@
 #
 import os
 import sys
-sys.path.insert(0, os.path.abspath("../../"))
 
-import sphinx_rtd_theme
+sys.path.insert(0, os.path.abspath("../../"))
 
 # -- Project information -----------------------------------------------------
 
@@ -128,15 +127,12 @@ latex_elements = {
 # The paper size ("letterpaper" or "a4paper").
 #
 # "papersize": "letterpaper",
-
 # The font size ("10pt", "11pt" or "12pt").
 #
 # "pointsize": "10pt",
-
 # Additional stuff for the LaTeX preamble.
 #
 # "preamble": "",
-
 # Latex figure (float) alignment
 #
 # "figure_align": "htbp",
@@ -146,15 +142,21 @@ latex_elements = {
 # (source start file, target name, title,
 # author, documentclass [howto, manual, or own class]).
 latex_documents = [
-(master_doc, "prototorch.tex", "ProtoTorch Documentation",
-"Jensun Ravichandran", "manual"),
+(
+master_doc,
+"prototorch.tex",
+"ProtoTorch Documentation",
+"Jensun Ravichandran",
+"manual",
+),
 ]
 
 # -- Options for manual page output ---------------------------------------
 
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)]
+man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author],
+1)]
 
 # -- Options for Texinfo output -------------------------------------------
 
@@ -162,9 +164,15 @@ man_pages = [(master_doc, "ProtoTorch", "ProtoTorch Documentation", [author], 1)
 # (source start file, target name, title, author,
 # dir menu entry, description, category)
 texinfo_documents = [
-(master_doc, "prototorch", "ProtoTorch Documentation", author, "prototorch",
-"Prototype-based machine learning in PyTorch.",
-"Miscellaneous"),
+(
+master_doc,
+"prototorch",
+"ProtoTorch Documentation",
+author,
+"prototorch",
+"Prototype-based machine learning in PyTorch.",
+"Miscellaneous",
+),
 ]
 
 # Example configuration for intersphinx: refer to the Python standard library.
@@ -3,13 +3,14 @@
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
+from sklearn.datasets import load_iris
+from sklearn.preprocessing import StandardScaler
+from torchinfo import summary
+
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import GLVQLoss
 from prototorch.modules.prototypes import Prototypes1D
-from sklearn.datasets import load_iris
-from sklearn.preprocessing import StandardScaler
-from torchinfo import summary
 
 # Prepare and preprocess the data
 scaler = StandardScaler()
@@ -29,7 +30,8 @@ class Model(torch.nn.Module):
 prototypes_per_class=3,
 nclasses=3,
 prototype_initializer="stratified_random",
-data=[x_train, y_train])
+data=[x_train, y_train],
+)
 
 def forward(self, x):
 protos = self.proto_layer.prototypes
@@ -61,8 +63,10 @@ for epoch in range(70):
 with torch.no_grad():
 pred = wtac(dis, plabels)
 correct = pred.eq(y_in.view_as(pred)).sum().item()
-acc = 100. * correct / len(x_train)
-print(f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%")
+acc = 100.0 * correct / len(x_train)
+print(
+f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} Acc: {acc:05.02f}%"
+)
 
 # Take a gradient descent step
 optimizer.zero_grad()
@@ -83,13 +87,15 @@ for epoch in range(70):
 ax.set_ylabel("Data dimension 2")
 cmap = "viridis"
 ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
-ax.scatter(protos[:, 0],
-protos[:, 1],
-c=plabels,
-cmap=cmap,
-edgecolor="k",
-marker="D",
-s=50)
+ax.scatter(
+protos[:, 0],
+protos[:, 1],
+c=plabels,
+cmap=cmap,
+edgecolor="k",
+marker="D",
+s=50,
+)
 
 # Paint decision regions
 x = np.vstack((x_train, protos))
@@ -20,11 +20,13 @@ class Model(torch.nn.Module):
 """GMLVQ model as a siamese network."""
 super().__init__()
 x, y = train_data.data, train_data.targets
-self.p1 = Prototypes1D(input_dim=100,
-prototypes_per_class=2,
-nclasses=2,
-prototype_initializer="stratified_random",
-data=[x, y])
+self.p1 = Prototypes1D(
+input_dim=100,
+prototypes_per_class=2,
+nclasses=2,
+prototype_initializer="stratified_random",
+data=[x, y],
+)
 self.omega = torch.nn.Linear(in_features=100,
 out_features=100,
 bias=False)
@@ -13,8 +13,9 @@ import torch
 import torch.nn as nn
 import torchvision
 from torchvision import transforms
-from prototorch.modules.losses import GLVQLoss
+
 from prototorch.functions.helper import calculate_prototype_accuracy
+from prototorch.modules.losses import GLVQLoss
 from prototorch.modules.models import GTLVQ
 
 # Parameters and options
@@ -26,32 +27,40 @@ momentum = 0.5
 log_interval = 10
 cuda = "cuda:1"
 random_seed = 1
-device = torch.device(cuda if torch.cuda.is_available() else 'cpu')
+device = torch.device(cuda if torch.cuda.is_available() else "cpu")
 
 # Configures reproducability
 torch.manual_seed(random_seed)
 np.random.seed(random_seed)
 
 # Prepare and preprocess the data
-train_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
-'./files/',
-train=True,
-download=True,
-transform=torchvision.transforms.Compose(
-[transforms.ToTensor(),
-transforms.Normalize((0.1307, ), (0.3081, ))])),
-batch_size=batch_size_train,
-shuffle=True)
+train_loader = torch.utils.data.DataLoader(
+torchvision.datasets.MNIST(
+"./files/",
+train=True,
+download=True,
+transform=torchvision.transforms.Compose([
+transforms.ToTensor(),
+transforms.Normalize((0.1307, ), (0.3081, ))
+]),
+),
+batch_size=batch_size_train,
+shuffle=True,
+)
 
-test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST(
-'./files/',
-train=False,
-download=True,
-transform=torchvision.transforms.Compose(
-[transforms.ToTensor(),
-transforms.Normalize((0.1307, ), (0.3081, ))])),
-batch_size=batch_size_test,
-shuffle=True)
+test_loader = torch.utils.data.DataLoader(
+torchvision.datasets.MNIST(
+"./files/",
+train=False,
+download=True,
+transform=torchvision.transforms.Compose([
+transforms.ToTensor(),
+transforms.Normalize((0.1307, ), (0.3081, ))
+]),
+),
+batch_size=batch_size_test,
+shuffle=True,
+)
 
 
 # Define the GLVQ model plus appropriate feature extractor
@@ -67,25 +76,34 @@ class CNNGTLVQ(torch.nn.Module):
 ):
 super(CNNGTLVQ, self).__init__()
 
-#Feature Extractor - Simple CNN
-self.fe = nn.Sequential(nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
-nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
-nn.MaxPool2d(2), nn.Dropout(0.25),
-nn.Flatten(), nn.Linear(9216, bottleneck_dim),
-nn.Dropout(0.5), nn.LeakyReLU(),
-nn.LayerNorm(bottleneck_dim))
+# Feature Extractor - Simple CNN
+self.fe = nn.Sequential(
+nn.Conv2d(1, 32, 3, 1),
+nn.ReLU(),
+nn.Conv2d(32, 64, 3, 1),
+nn.ReLU(),
+nn.MaxPool2d(2),
+nn.Dropout(0.25),
+nn.Flatten(),
+nn.Linear(9216, bottleneck_dim),
+nn.Dropout(0.5),
+nn.LeakyReLU(),
+nn.LayerNorm(bottleneck_dim),
+)
 
 # Forward pass of subspace and prototype initialization data through feature extractor
 subspace_data = self.fe(subspace_data)
 prototype_data[0] = self.fe(prototype_data[0])
 
 # Initialization of GTLVQ
-self.gtlvq = GTLVQ(num_classes,
-subspace_data,
-prototype_data,
-tangent_projection_type=tangent_projection_type,
-feature_dim=bottleneck_dim,
-prototypes_per_class=prototypes_per_class)
+self.gtlvq = GTLVQ(
+num_classes,
+subspace_data,
+prototype_data,
+tangent_projection_type=tangent_projection_type,
+feature_dim=bottleneck_dim,
+prototypes_per_class=prototypes_per_class,
+)
 
 def forward(self, x):
 # Feature Extraction
@@ -103,20 +121,24 @@ subspace_data = torch.cat(
 prototype_data = next(iter(train_loader))
 
 # Build the CNN GTLVQ model
-model = CNNGTLVQ(10,
-subspace_data,
-prototype_data,
-tangent_projection_type="local",
-bottleneck_dim=128).to(device)
+model = CNNGTLVQ(
+10,
+subspace_data,
+prototype_data,
+tangent_projection_type="local",
+bottleneck_dim=128,
+).to(device)
 
 # Optimize using SGD optimizer from `torch.optim`
-optimizer = torch.optim.Adam([{
-'params': model.fe.parameters()
-}, {
-'params': model.gtlvq.parameters()
-}],
-lr=learning_rate)
-criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
+optimizer = torch.optim.Adam(
+[{
+"params": model.fe.parameters()
+}, {
+"params": model.gtlvq.parameters()
+}],
+lr=learning_rate,
+)
+criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
 
 # Training loop
 for epoch in range(n_epochs):
@@ -139,8 +161,8 @@ for epoch in range(n_epochs):
 if batch_idx % log_interval == 0:
 acc = calculate_prototype_accuracy(distances, y_train, plabels)
 print(
-f'Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
-Train Acc: {acc.item():02.02f}')
+f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
+Train Acc: {acc.item():02.02f}")
 
 # Test
 with torch.no_grad():
@@ -154,9 +176,9 @@ for epoch in range(n_epochs):
 i = torch.argmin(test_distances, 1)
 correct += torch.sum(y_test == test_plabels[i])
 total += y_test.size(0)
-print('Accuracy of the network on the test images: %d %%' %
+print("Accuracy of the network on the test images: %d %%" %
 (torch.true_divide(correct, total) * 100))
 
 # Save the model
-PATH = './glvq_mnist_model.pth'
+PATH = "./glvq_mnist_model.pth"
 torch.save(model.state_dict(), PATH)
@@ -22,10 +22,12 @@ class Model(torch.nn.Module):
 def __init__(self):
 """Local-GMLVQ model."""
 super().__init__()
-self.p1 = Prototypes1D(input_dim=2,
-prototype_distribution=[1, 2, 2],
-prototype_initializer="stratified_random",
-data=[x_train, y_train])
+self.p1 = Prototypes1D(
+input_dim=2,
+prototype_distribution=[1, 2, 2],
+prototype_initializer="stratified_random",
+data=[x_train, y_train],
+)
 omegas = torch.zeros(5, 2, 2)
 self.omegas = torch.nn.Parameter(omegas)
 eye_(self.omegas)
@@ -76,14 +78,16 @@ for epoch in range(100):
 ax.set_xlabel("Data dimension 1")
 ax.set_ylabel("Data dimension 2")
 cmap = "viridis"
-ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
-ax.scatter(protos[:, 0],
-protos[:, 1],
-c=plabels,
-cmap=cmap,
-edgecolor='k',
-marker='D',
-s=50)
+ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
+ax.scatter(
+protos[:, 0],
+protos[:, 1],
+c=plabels,
+cmap=cmap,
+edgecolor="k",
+marker="D",
+s=50,
+)
 
 # Paint decision regions
 x = np.vstack((x_train, protos))
@@ -5,8 +5,6 @@
 # #############################################
 __version__ = "0.3.0-dev0"
 
-from prototorch import datasets, functions, modules
-
 __all_core__ = [
 "datasets",
 "functions",
@@ -17,6 +15,7 @@ __all_core__ = [
 # Plugin Loader
 # #############################################
 import pkgutil
+
 import pkg_resources
 
 __path__ = pkgutil.extend_path(__path__, __name__)
@@ -25,7 +24,8 @@ __path__ = pkgutil.extend_path(__path__, __name__)
 def discover_plugins():
 return {
 entry_point.name: entry_point.load()
-for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
+for entry_point in pkg_resources.iter_entry_points(
+"prototorch.plugins")
 }
 
 
@@ -33,14 +33,12 @@ discovered_plugins = discover_plugins()
 locals().update(discovered_plugins)
 
 # Generate combines __version__ and __all__
-version_plugins = "\n".join(
-[
-"- " + name + ": v" + plugin.__version__
-for name, plugin in discovered_plugins.items()
-]
-)
+version_plugins = "\n".join([
+"- " + name + ": v" + plugin.__version__
+for name, plugin in discovered_plugins.items()
+])
 if version_plugins != "":
 version_plugins = "\nPlugins: \n" + version_plugins
 
 version = "core: v" + __version__ + version_plugins
 __all__ = __all_core__ + list(discovered_plugins.keys())
@@ -3,5 +3,5 @@
 from .tecator import Tecator
 
 __all__ = [
-'Tecator',
+"Tecator",
 ]
@@ -52,7 +52,8 @@ class Tecator(ProtoDataset):
 """
 
 _resources = [
-("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0", "ba5607c580d0f91bb27dc29d13c2f8df"),
+("1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0",
+"ba5607c580d0f91bb27dc29d13c2f8df"),
 ] # (google_storage_id, md5hash)
 classes = ["0 - low_fat", "1 - high_fat"]
 
@@ -74,15 +75,15 @@ class Tecator(ProtoDataset):
 print("Downloading...")
 for fileid, md5 in self._resources:
 filename = "tecator.npz"
-download_file_from_google_drive(
-fileid, root=self.raw_folder, filename=filename, md5=md5
-)
+download_file_from_google_drive(fileid,
+root=self.raw_folder,
+filename=filename,
+md5=md5)
 
 if self.verbose:
 print("Processing...")
-with np.load(
-os.path.join(self.raw_folder, "tecator.npz"), allow_pickle=False
-) as f:
+with np.load(os.path.join(self.raw_folder, "tecator.npz"),
+allow_pickle=False) as f:
 x_train, y_train = f["x_train"], f["y_train"]
 x_test, y_test = f["x_test"], f["y_test"]
 training_set = [
@@ -94,9 +95,11 @@ class Tecator(ProtoDataset):
 torch.tensor(y_test),
 ]
 
-with open(os.path.join(self.processed_folder, self.training_file), "wb") as f:
+with open(os.path.join(self.processed_folder, self.training_file),
+"wb") as f:
 torch.save(training_set, f)
-with open(os.path.join(self.processed_folder, self.test_file), "wb") as f:
+with open(os.path.join(self.processed_folder, self.test_file),
+"wb") as f:
 torch.save(test_set, f)
 
 if self.verbose:
@@ -4,9 +4,9 @@ from .activations import identity, sigmoid_beta, swish_beta
 from .competitions import knnc, wtac
 
 __all__ = [
-'identity',
-'sigmoid_beta',
-'swish_beta',
-'knnc',
-'wtac',
+"identity",
+"sigmoid_beta",
+"swish_beta",
+"knnc",
+"wtac",
 ]
@@ -61,4 +61,4 @@ def get_activation(funcname):
 return funcname
 if funcname in ACTIVATIONS:
 return ACTIVATIONS.get(funcname)
-raise NameError(f'Activation {funcname} was not found.')
+raise NameError(f"Activation {funcname} was not found.")
@@ -12,7 +12,7 @@ def stratified_min(distances, labels):
 return distances
 batch_size = distances.size()[0]
 winning_distances = torch.zeros(nclasses, batch_size)
-inf = torch.full_like(distances.T, fill_value=float('inf'))
+inf = torch.full_like(distances.T, fill_value=float("inf"))
 # distances_to_wpluses = torch.where(matcher, distances, inf)
 for i, cl in enumerate(clabels):
 # cdists = distances.T[labels == cl]
@@ -1,12 +1,10 @@
 """ProtoTorch distance functions."""
 
-import torch
-from prototorch.functions.helper import (
-equal_int_shape,
-_int_and_mixed_shape,
-_check_shapes,
-)
 import numpy as np
+import torch
+
+from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
+equal_int_shape)
 
 
 def squared_euclidean_distance(x, y):
@@ -23,7 +23,7 @@ def predict_label(y_pred, plabels):
 
 def mixed_shape(inputs):
 if not torch.is_tensor(inputs):
-raise ValueError('Input must be a tensor.')
+raise ValueError("Input must be a tensor.")
 else:
 int_shape = list(inputs.shape)
 # sometimes int_shape returns mixed integer types
@@ -39,11 +39,11 @@ def mixed_shape(inputs):
 def equal_int_shape(shape_1, shape_2):
 if not isinstance(shape_1,
 (tuple, list)) or not isinstance(shape_2, (tuple, list)):
-raise ValueError('Input shapes must list or tuple.')
+raise ValueError("Input shapes must list or tuple.")
 for shape in [shape_1, shape_2]:
 if not all([isinstance(x, int) or x is None for x in shape]):
 raise ValueError(
-'Input shapes must be list or tuple of int and None values.')
+"Input shapes must be list or tuple of int and None values.")
 
 if len(shape_1) != len(shape_2):
 return False
@@ -104,4 +104,4 @@ def get_initializer(funcname):
 return funcname
 if funcname in INITIALIZERS:
 return INITIALIZERS.get(funcname)
-raise NameError(f'Initializer {funcname} was not found.')
+raise NameError(f"Initializer {funcname} was not found.")
@@ -11,7 +11,7 @@ def _get_dp_dm(distances, targets, plabels):
 matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
 not_matcher = torch.bitwise_not(matcher)
 
-inf = torch.full_like(distances, fill_value=float('inf'))
+inf = torch.full_like(distances, fill_value=float("inf"))
 d_matching = torch.where(matcher, distances, inf)
 d_unmatching = torch.where(not_matcher, distances, inf)
 dp = torch.min(d_matching, dim=1, keepdim=True).values
@@ -1,7 +1,5 @@
 # -*- coding: utf-8 -*-
-from __future__ import print_function
-from __future__ import absolute_import
-from __future__ import division
+from __future__ import absolute_import, division, print_function
 
 import torch
 
@@ -3,5 +3,5 @@
 from .prototypes import Prototypes1D
 
 __all__ = [
-'Prototypes1D',
+"Prototypes1D",
 ]
@@ -7,7 +7,7 @@ from prototorch.functions.losses import glvq_loss
 
 
 class GLVQLoss(torch.nn.Module):
-def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
+def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
 super().__init__(**kwargs)
 self.margin = margin
 self.squashing = get_activation(squashing)
@@ -37,4 +37,4 @@ class NeuralGasEnergy(torch.nn.Module):
 
 @staticmethod
 def _nghood_fn(rankings, lm):
 return torch.exp(-rankings / lm)
@@ -1,9 +1,11 @@
-from torch import nn
 import torch
-from prototorch.modules.prototypes import Prototypes1D
-from prototorch.functions.distances import tangent_distance, euclidean_distance_matrix
-from prototorch.functions.normalization import orthogonalization
+from torch import nn
+
+from prototorch.functions.distances import (euclidean_distance_matrix,
+tangent_distance)
 from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
+from prototorch.functions.normalization import orthogonalization
+from prototorch.modules.prototypes import Prototypes1D
 
 
 class GTLVQ(nn.Module):
@@ -71,7 +73,7 @@ class GTLVQ(nn.Module):
 subspace_data=None,
 prototype_data=None,
 subspace_size=256,
-tangent_projection_type='local',
+tangent_projection_type="local",
 prototypes_per_class=2,
 feature_dim=256,
 ):
@@ -82,37 +84,39 @@ class GTLVQ(nn.Module):
 self.feature_dim = feature_dim
 
 if subspace_data is None:
-raise ValueError('Init Data must be specified!')
+raise ValueError("Init Data must be specified!")
 
 self.tpt = tangent_projection_type
 with torch.no_grad():
-if self.tpt == 'local' or self.tpt == 'local_proj':
+if self.tpt == "local" or self.tpt == "local_proj":
 self.init_local_subspace(subspace_data)
-elif self.tpt == 'global':
+elif self.tpt == "global":
 self.init_gobal_subspace(subspace_data, subspace_size)
 else:
 self.subspaces = None
 
 # Hypothesis-Margin-Classifier
-self.cls = Prototypes1D(input_dim=feature_dim,
-prototypes_per_class=prototypes_per_class,
-nclasses=num_classes,
-prototype_initializer='stratified_mean',
-data=prototype_data)
+self.cls = Prototypes1D(
+input_dim=feature_dim,
+prototypes_per_class=prototypes_per_class,
+nclasses=num_classes,
+prototype_initializer="stratified_mean",
+data=prototype_data,
+)
 
 def forward(self, x):
 # Tangent Projection
-if self.tpt == 'local_proj':
-x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
-1).unsqueeze(2)
+if self.tpt == "local_proj":
+x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
+1).unsqueeze(2))
 dis, proj_x = self.local_tangent_projection(x_conform)
 
 proj_x = proj_x.reshape(x.shape[0] * self.num_protos,
 self.feature_dim)
 return proj_x, dis
 elif self.tpt == "local":
-x_conform = x.unsqueeze(1).repeat_interleave(self.num_protos,
-1).unsqueeze(2)
+x_conform = (x.unsqueeze(1).repeat_interleave(self.num_protos,
+1).unsqueeze(2))
 dis = tangent_distance(x_conform, self.cls.prototypes,
 self.subspaces)
 elif self.tpt == "gloabl":
@@ -127,25 +131,27 @@ class GTLVQ(nn.Module):
 _, _, v = torch.svd(data)
 subspace = (torch.eye(v.shape[0]) - (v @ v.T)).T
 subspaces = subspace[:, :num_subspaces]
-self.subspaces = torch.nn.Parameter(
-subspaces).clone().detach().requires_grad_(True)
+self.subspaces = (torch.nn.Parameter(
+subspaces).clone().detach().requires_grad_(True))
 
 def init_local_subspace(self, data):
 _, _, v = torch.svd(data)
 inital_projector = (torch.eye(v.shape[0]) - (v @ v.T)).T
 subspaces = inital_projector.unsqueeze(0).repeat_interleave(
 self.num_protos, 0)
-self.subspaces = torch.nn.Parameter(
-subspaces).clone().detach().requires_grad_(True)
+self.subspaces = (torch.nn.Parameter(
+subspaces).clone().detach().requires_grad_(True))
 
 def global_tangent_distances(self, x):
 # Tangent Projection
-x, projected_prototypes = x @ self.subspaces, self.cls.prototypes @ self.subspaces
+x, projected_prototypes = (
+x @ self.subspaces,
+self.cls.prototypes @ self.subspaces,
+)
 # Euclidean Distance
 return euclidean_distance_matrix(x, projected_prototypes)
 
-def local_tangent_projection(self,
-signals):
+def local_tangent_projection(self, signals):
 # Note: subspaces is always assumed as transposed and must be orthogonal!
 # shape(signals): batch x proto_number x channels x dim1 x dim2 x ... x dimN
 # shape(protos): proto_number x dim1 x dim2 x ... x dimN
@@ -183,8 +189,7 @@ class GTLVQ(nn.Module):
 def orthogonalize_subspace(self):
 if self.subspaces is not None:
 with torch.no_grad():
-ortho_subpsaces = orthogonalization(
-self.subspaces
-) if self.tpt == 'global' else torch.nn.init.orthogonal_(
-self.subspaces)
+ortho_subpsaces = (orthogonalization(self.subspaces)
+if self.tpt == "global" else
+torch.nn.init.orthogonal_(self.subspaces))
 self.subspaces.copy_(ortho_subpsaces)
@@ -29,14 +29,16 @@ class Prototypes1D(_Prototypes):
 
 TODO Complete this doc-string.
 """
-def __init__(self,
-prototypes_per_class=1,
-prototype_initializer="ones",
-prototype_distribution=None,
-data=None,
-dtype=torch.float32,
-one_hot_labels=False,
-**kwargs):
+def __init__(
+self,
+prototypes_per_class=1,
+prototype_initializer="ones",
+prototype_distribution=None,
+data=None,
+dtype=torch.float32,
+one_hot_labels=False,
+**kwargs,
+):
 
 # Convert tensors to python lists before processing
 if prototype_distribution is not None:
@@ -1 +0,0 @@
-from .colors import color_scheme, get_legend_handles
@@ -1,13 +1,13 @@
 """Easy matplotlib animation. From https://github.com/jwkvam/celluloid."""
 
-from typing import Dict, List
 from collections import defaultdict
+from typing import Dict, List
 
-from matplotlib.figure import Figure
-from matplotlib.artist import Artist
 from matplotlib.animation import ArtistAnimation
+from matplotlib.artist import Artist
+from matplotlib.figure import Figure
 
-__version__ = '0.2.0'
+__version__ = "0.2.0"
 
 
 class Camera:
@@ -19,7 +19,7 @@ class Camera:
 self._offsets: Dict[str, Dict[int, int]] = {
 k: defaultdict(int)
 for k in
-['collections', 'patches', 'lines', 'texts', 'artists', 'images']
+["collections", "patches", "lines", "texts", "artists", "images"]
 }
 self._photos: List[List[Artist]] = []
 
@@ -1,13 +1,14 @@
 """ProtoFlow color utilities."""
 
-from matplotlib import cm
-from matplotlib.colors import Normalize
-from matplotlib.colors import to_hex
-from matplotlib.colors import to_rgb
 import matplotlib.lines as mlines
+from matplotlib import cm
+from matplotlib.colors import Normalize, to_hex, to_rgb
 
 
-def color_scheme(n, cmap="viridis", form="hex", tikz=False,
+def color_scheme(n,
+cmap="viridis",
+form="hex",
+tikz=False,
 zero_indexed=False):
 """Return *n* colors from the color scheme.
 
@@ -57,13 +58,16 @@ def get_legend_handles(labels, marker="dots", zero_indexed=False):
 zero_indexed=zero_indexed)
 for label, color in zip(labels, colors.values()):
 if marker == "dots":
-handle = mlines.Line2D([], [],
-color="white",
-markerfacecolor=color,
-marker="o",
-markersize=10,
-markeredgecolor="k",
-label=label)
+handle = mlines.Line2D(
+[],
+[],
+color="white",
+markerfacecolor=color,
+marker="o",
+markersize=10,
+markeredgecolor="k",
+label=label,
+)
 else:
 handle = mlines.Line2D([], [],
 color=color,
@@ -11,17 +11,17 @@ import numpy as np
 
 def progressbar(title, value, end, bar_width=20):
 percent = float(value) / end
-arrow = '=' * int(round(percent * bar_width) - 1) + '>'
-spaces = '.' * (bar_width - len(arrow))
-sys.stdout.write('\r{}: [{}] {}%'.format(title, arrow + spaces,
+arrow = "=" * int(round(percent * bar_width) - 1) + ">"
+spaces = "." * (bar_width - len(arrow))
+sys.stdout.write("\r{}: [{}] {}%".format(title, arrow + spaces,
 int(round(percent * 100))))
 sys.stdout.flush()
 if percent == 1.0:
 print()
 
 
-def prettify_string(inputs, start='', sep=' ', end='\n'):
-outputs = start + ' '.join(inputs.split()) + end
+def prettify_string(inputs, start="", sep=" ", end="\n"):
+outputs = start + " ".join(inputs.split()) + end
 return outputs
 
 
@@ -29,22 +29,22 @@ def pretty_print(inputs):
 print(prettify_string(inputs))
 
 
-def writelog(self, *logs, logdir='./logs', logfile='run.txt'):
+def writelog(self, *logs, logdir="./logs", logfile="run.txt"):
 f = os.path.join(logdir, logfile)
-with open(f, 'a+') as fh:
+with open(f, "a+") as fh:
 for log in logs:
 fh.write(log)
-fh.write('\n')
+fh.write("\n")
 
 
-def start_tensorboard(self, logdir='./logs'):
-cmd = f'tensorboard --logdir={logdir} --port=6006'
+def start_tensorboard(self, logdir="./logs"):
+cmd = f"tensorboard --logdir={logdir} --port=6006"
 os.system(cmd)
 
 
 def make_directory(save_dir):
 if not os.path.exists(save_dir):
-print(f'Making directory {save_dir}.')
+print(f"Making directory {save_dir}.")
 os.mkdir(save_dir)
 
 
@@ -52,36 +52,36 @@ def make_gif(filenames, duration, output_file=None):
 try:
 import imageio
 except ModuleNotFoundError as e:
-print('Please install Protoflow with [other] extra requirements.')
+print("Please install Protoflow with [other] extra requirements.")
 raise (e)
 
 images = list()
 for filename in filenames:
 images.append(imageio.imread(filename))
 if not output_file:
-output_file = f'makegif.gif'
+output_file = f"makegif.gif"
 if images:
 imageio.mimwrite(output_file, images, duration=duration)
 
 
 def gif_from_dir(directory,
 duration,
-prefix='',
+prefix="",
 output_file=None,
 verbose=True):
 images = os.listdir(directory)
 if verbose:
-print(f'Making gif from {len(images)} images under {directory}.')
+print(f"Making gif from {len(images)} images under {directory}.")
 filenames = list()
 # Sort images
 images = sorted(
 images,
-key=lambda img: int(os.path.splitext(img)[0].replace(prefix, '')))
+key=lambda img: int(os.path.splitext(img)[0].replace(prefix, "")))
 for image in images:
 fname = os.path.join(directory, image)
 filenames.append(fname)
 if not output_file:
-output_file = os.path.join(directory, 'makegif.gif')
+output_file = os.path.join(directory, "makegif.gif")
 make_gif(filenames=filenames, duration=duration, output_file=output_file)
 
 
@@ -95,12 +95,12 @@ def predict_and_score(clf,
 x_test,
 y_test,
 verbose=False,
-title='Test accuracy'):
+title="Test accuracy"):
 y_pred = clf.predict(x_test)
 accuracy = np.sum(y_test == y_pred)
 normalized_acc = accuracy / float(len(y_test))
 if verbose:
-print(f'{title}: {normalized_acc * 100:06.04f}%')
+print(f"{title}: {normalized_acc * 100:06.04f}%")
 return normalized_acc
 
 
@@ -124,6 +124,7 @@ def replace_in(arr, replacement_dict, inplace=False):
 new_arr = arr
 else:
 import copy
+
 new_arr = copy.deepcopy(arr)
 for k, v in replacement_dict.items():
 new_arr[arr == k] = v
@@ -135,7 +136,7 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
 preserve the class distribution in subsamples of the dataset.
 """
 if train + val > 1.0:
-raise ValueError('Invalid split values for train and val.')
+raise ValueError("Invalid split values for train and val.")
 Y = data[:, -1]
 labels = set(Y)
 hist = dict()
@@ -183,20 +184,20 @@ def train_test_split(data, train=0.7, val=0.15, shuffle=None, return_xy=False):
 return train_data, val_data, test_data
 
 
-def class_histogram(data, title='Untitled'):
+def class_histogram(data, title="Untitled"):
 plt.figure(title)
 plt.clf()
 plt.title(title)
 dist, counts = np.unique(data[:, -1], return_counts=True)
 plt.bar(dist, counts)
 plt.xticks(dist)
-print('Call matplotlib.pyplot.show() to see the plot.')
+print("Call matplotlib.pyplot.show() to see the plot.")
 
 
 def ntimer(n=10):
 """Wraps a function which wraps another function to time it."""
 if n < 1:
-raise (Exception(f'Invalid n = {n} given.'))
+raise (Exception(f"Invalid n = {n} given."))
 
 def timer(func):
 """Wraps `func` with a timer and returns the wrapped `func`."""
@@ -207,7 +208,7 @@ def ntimer(n=10):
 rv = func(*args, **kwargs)
 after = time()
 elapsed = after - before
-print(f'Elapsed: {elapsed*1e3:02.02f} ms')
+print(f"Elapsed: {elapsed*1e3:02.02f} ms")
 return rv
 
 return wrapper
@@ -228,15 +229,15 @@ def memoize(verbose=True):
 t = (pickle.dumps(args), pickle.dumps(kwargs))
 if t not in cache:
 if verbose:
-print(f'Adding NEW rv {func.__name__}{args}{kwargs} '
-'to cache.')
+print(f"Adding NEW rv {func.__name__}{args}{kwargs} "
+"to cache.")
 cache[t] = func(*args, **kwargs)
 else:
 if verbose:
-print(f'Using OLD rv {func.__name__}{args}{kwargs} '
-'from cache.')
+print(f"Using OLD rv {func.__name__}{args}{kwargs} "
+"from cache.")
 return cache[t]
 
 return wrapper
 
 return memoizer
setup.py (3 changed lines)
@@ -8,8 +8,7 @@
 
 ProtoTorch Core Package
 """
-from setuptools import setup
-from setuptools import find_packages
+from setuptools import find_packages, setup
 
 PROJECT_URL = "https://github.com/si-cim/prototorch"
 DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
@@ -12,26 +12,26 @@ from prototorch.datasets import abstract, tecator
 class TestAbstract(unittest.TestCase):
 def test_getitem(self):
 with self.assertRaises(NotImplementedError):
-abstract.Dataset('./artifacts')[0]
+abstract.Dataset("./artifacts")[0]
 
 def test_len(self):
 with self.assertRaises(NotImplementedError):
-len(abstract.Dataset('./artifacts'))
+len(abstract.Dataset("./artifacts"))
 
 
 class TestProtoDataset(unittest.TestCase):
 def test_getitem(self):
 with self.assertRaises(NotImplementedError):
-abstract.ProtoDataset('./artifacts')[0]
+abstract.ProtoDataset("./artifacts")[0]
 
 def test_download(self):
 with self.assertRaises(NotImplementedError):
-abstract.ProtoDataset('./artifacts').download()
+abstract.ProtoDataset("./artifacts").download()
 
 
 class TestTecator(unittest.TestCase):
 def setUp(self):
-self.artifacts_dir = './artifacts/Tecator'
+self.artifacts_dir = "./artifacts/Tecator"
 self._remove_artifacts()
 
 def _remove_artifacts(self):
@@ -39,23 +39,23 @@ class TestTecator(unittest.TestCase):
 shutil.rmtree(self.artifacts_dir)
 
 def test_download_false(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 self._remove_artifacts()
 with self.assertRaises(RuntimeError):
 _ = tecator.Tecator(rootdir, download=False)
 
 def test_download_caching(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 _ = tecator.Tecator(rootdir, download=True, verbose=False)
 _ = tecator.Tecator(rootdir, download=False, verbose=False)
 
 def test_repr(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 train = tecator.Tecator(rootdir, download=True, verbose=True)
-self.assertTrue('Split: Train' in train.__repr__())
+self.assertTrue("Split: Train" in train.__repr__())
 
 def test_download_train(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 train = tecator.Tecator(root=rootdir,
 train=True,
 download=True,
@@ -67,7 +67,7 @@ class TestTecator(unittest.TestCase):
 self.assertEqual(x_train.shape[1], 100)
 
 def test_download_test(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 test = tecator.Tecator(root=rootdir, train=False, verbose=False)
 x_test, y_test = test.data, test.targets
 self.assertEqual(x_test.shape[0], 71)
@@ -75,19 +75,19 @@ class TestTecator(unittest.TestCase):
 self.assertEqual(x_test.shape[1], 100)
 
 def test_class_to_idx(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 test = tecator.Tecator(root=rootdir, train=False, verbose=False)
 _ = test.class_to_idx
 
 def test_getitem(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 test = tecator.Tecator(root=rootdir, train=False, verbose=False)
 x, y = test[0]
 self.assertEqual(x.shape[0], 100)
 self.assertIsInstance(y, int)
 
 def test_loadable_with_dataloader(self):
-rootdir = self.artifacts_dir.rpartition('/')[0]
+rootdir = self.artifacts_dir.rpartition("/")[0]
 test = tecator.Tecator(root=rootdir, train=False, verbose=False)
 _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
 
@@ -11,7 +11,7 @@ from prototorch.functions import (activations, competitions, distances,
 
 class TestActivations(unittest.TestCase):
 def setUp(self):
-self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
+self.flist = ["identity", "sigmoid_beta", "swish_beta"]
 self.x = torch.randn(1024, 1)
 
 def test_registry(self):
@@ -39,7 +39,7 @@ class TestActivations(unittest.TestCase):
 self.assertEqual(1, f(1))
 
 def test_unknown_deserialization(self):
-for funcname in ['blubb', 'foobar']:
+for funcname in ["blubb", "foobar"]:
 with self.assertRaises(NameError):
 _ = activations.get_activation(funcname)
 
@@ -76,7 +76,7 @@ class TestCompetitions(unittest.TestCase):
 pass
 
 def test_wtac(self):
-d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
+d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
 labels = torch.tensor([0, 1, 2, 3])
 actual = competitions.wtac(d, labels)
 desired = torch.tensor([2, 0])
@@ -86,7 +86,7 @@ class TestCompetitions(unittest.TestCase):
 self.assertIsNone(mismatch)
 
 def test_wtac_unequal_dist(self):
-d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
+d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
 labels = torch.tensor([0, 1, 1])
 actual = competitions.wtac(d, labels)
 desired = torch.tensor([0, 1])
@@ -96,7 +96,7 @@ class TestCompetitions(unittest.TestCase):
 self.assertIsNone(mismatch)
 
 def test_wtac_one_hot(self):
-d = torch.tensor([[1.99, 3.01], [3., 2.01]])
+d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
 labels = torch.tensor([[0, 1], [1, 0]])
 actual = competitions.wtac(d, labels)
 desired = torch.tensor([[0, 1], [1, 0]])
@@ -106,38 +106,38 @@ class TestCompetitions(unittest.TestCase):
 self.assertIsNone(mismatch)
 
 def test_stratified_min(self):
-d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
+d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
 labels = torch.tensor([0, 0, 1, 2])
 actual = competitions.stratified_min(d, labels)
-desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
+desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
 mismatch = np.testing.assert_array_almost_equal(actual,
 desired,
 decimal=5)
 self.assertIsNone(mismatch)
 
 def test_stratified_min_one_hot(self):
-d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
+d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
 labels = torch.tensor([0, 0, 1, 2])
 labels = torch.eye(3)[labels]
 actual = competitions.stratified_min(d, labels)
-desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
+desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
 mismatch = np.testing.assert_array_almost_equal(actual,
 desired,
 decimal=5)
 self.assertIsNone(mismatch)
 
 def test_stratified_min_simple(self):
-d = torch.tensor([[0., 2., 3.], [8., 0, 1]])
+d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
 labels = torch.tensor([0, 1, 2])
 actual = competitions.stratified_min(d, labels)
-desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
+desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
 mismatch = np.testing.assert_array_almost_equal(actual,
 desired,
 decimal=5)
 self.assertIsNone(mismatch)
 
 def test_knnc_k1(self):
-d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
+d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
 labels = torch.tensor([0, 1, 2, 3])
 actual = competitions.knnc(d, labels, k=torch.tensor([1]))
 desired = torch.tensor([2, 0])
@@ -194,12 +194,12 @@ class TestDistances(unittest.TestCase):
 desired = torch.empty(self.nx, self.ny)
 for i in range(self.nx):
 for j in range(self.ny):
-desired[i][j] = torch.nn.functional.pairwise_distance(
+desired[i][j] = (torch.nn.functional.pairwise_distance(
 self.x[i].reshape(1, -1),
 self.y[j].reshape(1, -1),
 p=2,
 keepdim=False,
-)**2
+)**2)
 mismatch = np.testing.assert_array_almost_equal(actual,
 desired,
 decimal=2)
@@ -254,14 +254,14 @@ class TestDistances(unittest.TestCase):
 self.assertIsNone(mismatch)
 
 def test_lpnorm_pinf(self):
-actual = distances.lpnorm_distance(self.x, self.y, p=float('inf'))
+actual = distances.lpnorm_distance(self.x, self.y, p=float("inf"))
 desired = torch.empty(self.nx, self.ny)
 for i in range(self.nx):
 for j in range(self.ny):
 desired[i][j] = torch.nn.functional.pairwise_distance(
 self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=float('inf'),
|
p=float("inf"),
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)
|
)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
@ -275,12 +275,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@ -293,12 +293,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
desired = torch.empty(self.nx, self.ny)
|
desired = torch.empty(self.nx, self.ny)
|
||||||
for i in range(self.nx):
|
for i in range(self.nx):
|
||||||
for j in range(self.ny):
|
for j in range(self.ny):
|
||||||
desired[i][j] = torch.nn.functional.pairwise_distance(
|
desired[i][j] = (torch.nn.functional.pairwise_distance(
|
||||||
self.x[i].reshape(1, -1),
|
self.x[i].reshape(1, -1),
|
||||||
self.y[j].reshape(1, -1),
|
self.y[j].reshape(1, -1),
|
||||||
p=2,
|
p=2,
|
||||||
keepdim=False,
|
keepdim=False,
|
||||||
)**2
|
)**2)
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=2)
|
decimal=2)
|
||||||
@ -311,8 +311,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
class TestInitializers(unittest.TestCase):
|
class TestInitializers(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flist = [
|
self.flist = [
|
||||||
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
"zeros",
|
||||||
'stratified_random'
|
"ones",
|
||||||
|
"rand",
|
||||||
|
"randn",
|
||||||
|
"stratified_mean",
|
||||||
|
"stratified_random",
|
||||||
]
|
]
|
||||||
self.x = torch.tensor(
|
self.x = torch.tensor(
|
||||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||||
@ -340,7 +344,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
self.assertEqual(1, f(1))
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
def test_unknown_deserialization(self):
|
def test_unknown_deserialization(self):
|
||||||
for funcname in ['blubb', 'foobar']:
|
for funcname in ["blubb", "foobar"]:
|
||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
_ = initializers.get_initializer(funcname)
|
_ = initializers.get_initializer(funcname)
|
||||||
|
|
||||||
@ -383,7 +387,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_equal1(self):
|
def test_stratified_mean_equal1(self):
|
||||||
pdist = torch.tensor([1, 1])
|
pdist = torch.tensor([1, 1])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
|
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -393,7 +397,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 1])
|
pdist = torch.tensor([1, 1])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -402,8 +406,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_equal2(self):
|
def test_stratified_mean_equal2(self):
|
||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
|
desired = torch.tensor([[5.0, 5.0, 5.0], [5.0, 5.0, 5.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -413,8 +417,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, -1.0, -2.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -423,8 +427,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_unequal(self):
|
def test_stratified_mean_unequal(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
|
||||||
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
|
desired = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -434,8 +438,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
|
||||||
False)
|
False)
|
||||||
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
|
desired = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
desired,
|
desired,
|
||||||
decimal=5)
|
decimal=5)
|
||||||
@ -444,8 +448,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def test_stratified_mean_unequal_one_hot(self):
|
def test_stratified_mean_unequal_one_hot(self):
|
||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
y = torch.eye(2)[self.y]
|
y = torch.eye(2)[self.y]
|
||||||
desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
|
desired1 = torch.tensor([[5.0, 5.0, 5.0], [1.0, 1.0, 1.0],
|
||||||
[1., 1., 1.]])
|
[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
|
||||||
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
|
||||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||||
@ -460,8 +464,8 @@ class TestInitializers(unittest.TestCase):
|
|||||||
pdist = torch.tensor([1, 3])
|
pdist = torch.tensor([1, 3])
|
||||||
y = torch.eye(2)[self.y]
|
y = torch.eye(2)[self.y]
|
||||||
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
|
||||||
desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
|
desired1 = torch.tensor([[0.0, -1.0, -2.0], [0.0, 0.0, 0.0],
|
||||||
[0., 0., 0.]])
|
[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
|
||||||
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual1,
|
mismatch = np.testing.assert_array_almost_equal(actual1,
|
||||||
desired1,
|
desired1,
|
||||||
|
@ -29,10 +29,12 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
|
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
|
||||||
|
|
||||||
def test_prototypes1d_init_without_pdist(self):
|
def test_prototypes1d_init_without_pdist(self):
|
||||||
p1 = prototypes.Prototypes1D(input_dim=6,
|
p1 = prototypes.Prototypes1D(
|
||||||
nclasses=2,
|
input_dim=6,
|
||||||
prototypes_per_class=4,
|
nclasses=2,
|
||||||
prototype_initializer='ones')
|
prototypes_per_class=4,
|
||||||
|
prototype_initializer="ones",
|
||||||
|
)
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.ones(8, 6)
|
desired = torch.ones(8, 6)
|
||||||
@ -45,7 +47,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
pdist = [2, 2]
|
pdist = [2, 2]
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||||
prototype_distribution=pdist,
|
prototype_distribution=pdist,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer="zeros")
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(4, 3)
|
desired = torch.zeros(4, 3)
|
||||||
@ -60,14 +62,15 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=3,
|
input_dim=3,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=None)
|
data=None,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_torch_pdist(self):
|
def test_prototypes1d_init_torch_pdist(self):
|
||||||
pdist = torch.tensor([2, 2])
|
pdist = torch.tensor([2, 2])
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||||
prototype_distribution=pdist,
|
prototype_distribution=pdist,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer="zeros")
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(4, 3)
|
desired = torch.zeros(4, 3)
|
||||||
@ -77,24 +80,30 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_prototypes1d_init_without_inputdim_with_data(self):
|
def test_prototypes1d_init_without_inputdim_with_data(self):
|
||||||
_ = prototypes.Prototypes1D(nclasses=2,
|
_ = prototypes.Prototypes1D(
|
||||||
prototypes_per_class=1,
|
nclasses=2,
|
||||||
prototype_initializer='stratified_mean',
|
prototypes_per_class=1,
|
||||||
data=[[[1.], [0.]], [1, 0]])
|
prototype_initializer="stratified_mean",
|
||||||
|
data=[[[1.0], [0.0]], [1, 0]],
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_data(self):
|
def test_prototypes1d_init_with_int_data(self):
|
||||||
_ = prototypes.Prototypes1D(nclasses=2,
|
_ = prototypes.Prototypes1D(
|
||||||
prototypes_per_class=1,
|
nclasses=2,
|
||||||
prototype_initializer='stratified_mean',
|
prototypes_per_class=1,
|
||||||
data=[[[1], [0]], [1, 0]])
|
prototype_initializer="stratified_mean",
|
||||||
|
data=[[[1], [0]], [1, 0]],
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_without_data(self):
|
def test_prototypes1d_init_one_hot_without_data(self):
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
_ = prototypes.Prototypes1D(
|
||||||
nclasses=2,
|
input_dim=1,
|
||||||
prototypes_per_class=1,
|
nclasses=2,
|
||||||
prototype_initializer='stratified_mean',
|
prototypes_per_class=1,
|
||||||
data=None,
|
prototype_initializer="stratified_mean",
|
||||||
one_hot_labels=True)
|
data=None,
|
||||||
|
one_hot_labels=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_false(self):
|
def test_prototypes1d_init_one_hot_labels_false(self):
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
|
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
|
||||||
@ -105,9 +114,10 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.], [1.]], [[0, 1], [1, 0]]),
|
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
|
||||||
one_hot_labels=False)
|
one_hot_labels=False,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
|
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
||||||
@ -118,9 +128,10 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.], [1.]], [0, 1]),
|
data=([[0.0], [1.0]], [0, 1]),
|
||||||
one_hot_labels=True)
|
one_hot_labels=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_true(self):
|
def test_prototypes1d_init_one_hot_labels_true(self):
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
||||||
@ -132,25 +143,27 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.], [1.]], [[0], [1]]),
|
data=([[0.0], [1.0]], [[0], [1]]),
|
||||||
one_hot_labels=True)
|
one_hot_labels=True,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_dtype(self):
|
def test_prototypes1d_init_with_int_dtype(self):
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1], [0]], [1, 0]],
|
data=[[[1], [0]], [1, 0]],
|
||||||
dtype=torch.int32)
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_inputndim_with_data(self):
|
def test_prototypes1d_inputndim_with_data(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
_ = prototypes.Prototypes1D(input_dim=1,
|
||||||
nclasses=1,
|
nclasses=1,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
data=[[1.], [1]])
|
data=[[1.0], [1]])
|
||||||
|
|
||||||
def test_prototypes1d_inputdim_with_data(self):
|
def test_prototypes1d_inputdim_with_data(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@ -158,8 +171,9 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=2,
|
input_dim=2,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1.], [0.]], [1, 0]])
|
data=[[[1.0], [0.0]], [1, 0]],
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_nclasses_with_data(self):
|
def test_prototypes1d_nclasses_with_data(self):
|
||||||
"""Test ValueError raise if provided `nclasses` is not the same
|
"""Test ValueError raise if provided `nclasses` is not the same
|
||||||
@ -170,13 +184,14 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=1,
|
nclasses=1,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer='stratified_mean',
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1.], [2.]], [1, 2]])
|
data=[[[1.0], [2.0]], [1, 2]],
|
||||||
|
)
|
||||||
|
|
||||||
def test_prototypes1d_init_with_ppc(self):
|
def test_prototypes1d_init_with_ppc(self):
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
||||||
prototypes_per_class=2,
|
prototypes_per_class=2,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer="zeros")
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(4, 3)
|
desired = torch.zeros(4, 3)
|
||||||
@ -186,9 +201,11 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_prototypes1d_init_with_pdist(self):
|
def test_prototypes1d_init_with_pdist(self):
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
p1 = prototypes.Prototypes1D(
|
||||||
prototype_distribution=[6, 9],
|
data=[self.x, self.y],
|
||||||
prototype_initializer='zeros')
|
prototype_distribution=[6, 9],
|
||||||
|
prototype_initializer="zeros",
|
||||||
|
)
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.zeros(15, 3)
|
desired = torch.zeros(15, 3)
|
||||||
@ -201,10 +218,12 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
def my_initializer(*args, **kwargs):
|
def my_initializer(*args, **kwargs):
|
||||||
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
|
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
|
||||||
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=99,
|
p1 = prototypes.Prototypes1D(
|
||||||
nclasses=2,
|
input_dim=99,
|
||||||
prototypes_per_class=1,
|
nclasses=2,
|
||||||
prototype_initializer=my_initializer)
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer=my_initializer,
|
||||||
|
)
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = 99 * torch.ones(2, 99)
|
desired = 99 * torch.ones(2, 99)
|
||||||
@ -231,7 +250,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
def test_prototypes1d_validate_extra_repr_not_empty(self):
|
def test_prototypes1d_validate_extra_repr_not_empty(self):
|
||||||
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
||||||
rep = p1.extra_repr()
|
rep = p1.extra_repr()
|
||||||
self.assertNotEqual(rep, '')
|
self.assertNotEqual(rep, "")
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
del self.x, self.y, self.gen
|
del self.x, self.y, self.gen
|
||||||
@ -243,11 +262,11 @@ class TestLosses(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_glvqloss_init(self):
|
def test_glvqloss_init(self):
|
||||||
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
_ = losses.GLVQLoss(0, "swish_beta", beta=20)
|
||||||
|
|
||||||
def test_glvqloss_forward_1ppc(self):
|
def test_glvqloss_forward_1ppc(self):
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
squashing='sigmoid_beta',
|
squashing="sigmoid_beta",
|
||||||
beta=100)
|
beta=100)
|
||||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
labels = torch.tensor([0, 1])
|
labels = torch.tensor([0, 1])
|
||||||
@ -259,7 +278,7 @@ class TestLosses(unittest.TestCase):
|
|||||||
|
|
||||||
def test_glvqloss_forward_2ppc(self):
|
def test_glvqloss_forward_2ppc(self):
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
squashing='sigmoid_beta',
|
squashing="sigmoid_beta",
|
||||||
beta=100)
|
beta=100)
|
||||||
d = torch.stack([
|
d = torch.stack([
|
||||||
torch.ones(100),
|
torch.ones(100),
|
||||||