Use 'num_' in all variable names

This commit is contained in:
Alexander Engelsberger 2021-05-25 15:57:05 +02:00
parent aff7a385a3
commit 73e6fe384e
11 changed files with 84 additions and 89 deletions

View File

@ -3,14 +3,13 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data # Prepare and preprocess the data
scaler = StandardScaler() scaler = StandardScaler()
@ -28,7 +27,7 @@ class Model(torch.nn.Module):
self.proto_layer = Prototypes1D( self.proto_layer = Prototypes1D(
input_dim=2, input_dim=2,
prototypes_per_class=3, prototypes_per_class=3,
nclasses=3, num_classes=3,
prototype_initializer="stratified_random", prototype_initializer="stratified_random",
data=[x_train, y_train], data=[x_train, y_train],
) )

View File

@ -2,13 +2,12 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from torch.utils.data import DataLoader
from prototorch.datasets.tecator import Tecator from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed from prototorch.functions.distances import sed
from prototorch.modules import Prototypes1D from prototorch.modules import Prototypes1D
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles from prototorch.utils.colors import get_legend_handles
from torch.utils.data import DataLoader
# Prepare the dataset and dataloader # Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True) train_data = Tecator(root="./artifacts", train=True)
@ -23,7 +22,7 @@ class Model(torch.nn.Module):
self.p1 = Prototypes1D( self.p1 = Prototypes1D(
input_dim=100, input_dim=100,
prototypes_per_class=2, prototypes_per_class=2,
nclasses=2, num_classes=2,
prototype_initializer="stratified_random", prototype_initializer="stratified_random",
data=[x, y], data=[x, y],
) )

View File

@ -12,14 +12,13 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from torchvision import transforms
from prototorch.functions.helper import calculate_prototype_accuracy from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ from prototorch.modules.models import GTLVQ
from torchvision import transforms
# Parameters and options # Parameters and options
n_epochs = 50 num_epochs = 50
batch_size_train = 64 batch_size_train = 64
batch_size_test = 1000 batch_size_test = 1000
learning_rate = 0.1 learning_rate = 0.1
@ -141,7 +140,7 @@ optimizer = torch.optim.Adam(
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10) criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
# Training loop # Training loop
for epoch in range(n_epochs): for epoch in range(num_epochs):
for batch_idx, (x_train, y_train) in enumerate(train_loader): for batch_idx, (x_train, y_train) in enumerate(train_loader):
model.train() model.train()
x_train, y_train = x_train.to(device), y_train.to(device) x_train, y_train = x_train.to(device), y_train.to(device)
@ -161,7 +160,7 @@ for epoch in range(n_epochs):
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
acc = calculate_prototype_accuracy(distances, y_train, plabels) acc = calculate_prototype_accuracy(distances, y_train, plabels)
print( print(
f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \ f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
Train Acc: {acc.item():02.02f}") Train Acc: {acc.item():02.02f}")
# Test # Test

View File

@ -139,8 +139,8 @@ class ReasoningComponents(Components):
def _initialize_reasonings(self, reasonings): def _initialize_reasonings(self, reasonings):
if type(reasonings) == tuple: if type(reasonings) == tuple:
nclasses, ncomps = reasonings num_classes, ncomps = reasonings
reasonings = ZeroReasoningsInitializer(nclasses, ncomps) reasonings = ZeroReasoningsInitializer(num_classes, ncomps)
_reasonings = reasonings.generate() _reasonings = reasonings.generate()
self.register_parameter("_reasonings", _reasonings) self.register_parameter("_reasonings", _reasonings)

View File

@ -4,7 +4,7 @@ import numpy as np
import torch import torch
def make_spiral(n_samples=500, noise=0.3): def make_spiral(num_samples=500, noise=0.3):
"""Generates the Spiral Dataset. """Generates the Spiral Dataset.
For use in Prototorch use `prototorch.datasets.Spiral` instead. For use in Prototorch use `prototorch.datasets.Spiral` instead.
@ -12,14 +12,14 @@ def make_spiral(n_samples=500, noise=0.3):
def get_samples(n, delta_t): def get_samples(n, delta_t):
points = [] points = []
for i in range(n): for i in range(n):
r = i / n_samples * 5 r = i / num_samples * 5
t = 1.75 * i / n * 2 * np.pi + delta_t t = 1.75 * i / n * 2 * np.pi + delta_t
x = r * np.sin(t) + np.random.rand(1) * noise x = r * np.sin(t) + np.random.rand(1) * noise
y = r * np.cos(t) + np.random.rand(1) * noise y = r * np.cos(t) + np.random.rand(1) * noise
points.append([x, y]) points.append([x, y])
return points return points
n = n_samples // 2 n = num_samples // 2
positive = get_samples(n=n, delta_t=0) positive = get_samples(n=n, delta_t=0)
negative = get_samples(n=n, delta_t=np.pi) negative = get_samples(n=n, delta_t=np.pi)
x = np.concatenate( x = np.concatenate(
@ -45,13 +45,13 @@ class Spiral(torch.utils.data.TensorDataset):
- test size - test size
* - 2 * - 2
- 2 - 2
- n_samples - num_samples
- 0 - 0
- 0 - 0
:param n_samples: number of random samples :param num_samples: number of random samples
:param noise: noise added to the spirals :param noise: noise added to the spirals
""" """
def __init__(self, n_samples: int = 500, noise: float = 0.3): def __init__(self, num_samples: int = 500, noise: float = 0.3):
x, y = make_spiral(n_samples, noise) x, y = make_spiral(num_samples, noise)
super().__init__(torch.Tensor(x), torch.LongTensor(y)) super().__init__(torch.Tensor(x), torch.LongTensor(y))

View File

@ -5,12 +5,12 @@ import torch
def stratified_min(distances, labels): def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0) clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0] num_classes = clabels.size()[0]
if distances.size()[1] == nclasses: if distances.size()[1] == num_classes:
# skip if only one prototype per class # skip if only one prototype per class
return distances return distances
batch_size = distances.size()[0] batch_size = distances.size()[0]
winning_distances = torch.zeros(nclasses, batch_size) winning_distances = torch.zeros(num_classes, batch_size)
inf = torch.full_like(distances.T, fill_value=float("inf")) inf = torch.full_like(distances.T, fill_value=float("inf"))
# distances_to_wpluses = torch.where(matcher, distances, inf) # distances_to_wpluses = torch.where(matcher, distances, inf)
for i, cl in enumerate(clabels): for i, cl in enumerate(clabels):
@ -18,7 +18,7 @@ def stratified_min(distances, labels):
matcher = torch.eq(labels.unsqueeze(dim=1), cl) matcher = torch.eq(labels.unsqueeze(dim=1), cl)
if labels.ndim == 2: if labels.ndim == 2:
# if the labels are one-hot vectors # if the labels are one-hot vectors
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
cdists = torch.where(matcher, distances.T, inf).T cdists = torch.where(matcher, distances.T, inf).T
winning_distances[i] = torch.min(cdists, dim=1, winning_distances[i] = torch.min(cdists, dim=1,
keepdim=True).values.squeeze() keepdim=True).values.squeeze()

View File

@ -15,59 +15,59 @@ def register_initializer(function):
def labels_from(distribution, one_hot=True): def labels_from(distribution, one_hot=True):
"""Takes a distribution tensor and returns a labels tensor.""" """Takes a distribution tensor and returns a labels tensor."""
nclasses = distribution.shape[0] num_classes = distribution.shape[0]
llist = [[i] * n for i, n in zip(range(nclasses), distribution)] llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
# labels = [l for cl in llist for l in cl] # flatten the list of lists # labels = [l for cl in llist for l in cl] # flatten the list of lists
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
plabels = torch.tensor(flat_llist, requires_grad=False) plabels = torch.tensor(flat_llist, requires_grad=False)
if one_hot: if one_hot:
return torch.eye(nclasses)[plabels] return torch.eye(num_classes)[plabels]
return plabels return plabels
@register_initializer @register_initializer
def ones(x_train, y_train, prototype_distribution, one_hot=True): def ones(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) num_protos = torch.sum(prototype_distribution)
protos = torch.ones(nprotos, *x_train.shape[1:]) protos = torch.ones(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def zeros(x_train, y_train, prototype_distribution, one_hot=True): def zeros(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) num_protos = torch.sum(prototype_distribution)
protos = torch.zeros(nprotos, *x_train.shape[1:]) protos = torch.zeros(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def rand(x_train, y_train, prototype_distribution, one_hot=True): def rand(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) num_protos = torch.sum(prototype_distribution)
protos = torch.rand(nprotos, *x_train.shape[1:]) protos = torch.rand(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def randn(x_train, y_train, prototype_distribution, one_hot=True): def randn(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) num_protos = torch.sum(prototype_distribution)
protos = torch.randn(nprotos, *x_train.shape[1:]) protos = torch.randn(num_protos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True): def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) num_protos = torch.sum(prototype_distribution)
pdim = x_train.shape[1] pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim) protos = torch.empty(num_protos, pdim)
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels): for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train) matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot: if one_hot:
nclasses = y_train.size()[1] num_classes = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
xl = x_train[matcher] xl = x_train[matcher]
mean_xl = torch.mean(xl, dim=0) mean_xl = torch.mean(xl, dim=0)
protos[i] = mean_xl protos[i] = mean_xl
@ -81,15 +81,15 @@ def stratified_random(x_train,
prototype_distribution, prototype_distribution,
one_hot=True, one_hot=True,
epsilon=1e-7): epsilon=1e-7):
nprotos = torch.sum(prototype_distribution) num_protos = torch.sum(prototype_distribution)
pdim = x_train.shape[1] pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim) protos = torch.empty(num_protos, pdim)
plabels = labels_from(prototype_distribution, one_hot) plabels = labels_from(prototype_distribution, one_hot)
for i, label in enumerate(plabels): for i, label in enumerate(plabels):
matcher = torch.eq(label.unsqueeze(dim=0), y_train) matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot: if one_hot:
nclasses = y_train.size()[1] num_classes = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
xl = x_train[matcher] xl = x_train[matcher]
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1) rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
random_xl = xl[rand_index] random_xl = xl[rand_index]

View File

@ -8,8 +8,8 @@ def _get_matcher(targets, labels):
matcher = torch.eq(targets.unsqueeze(dim=1), labels) matcher = torch.eq(targets.unsqueeze(dim=1), labels)
if labels.ndim == 2: if labels.ndim == 2:
# if the labels are one-hot vectors # if the labels are one-hot vectors
nclasses = targets.size()[1] num_classes = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
return matcher return matcher

View File

@ -1,11 +1,10 @@
import torch import torch
from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix, from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance) tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
from torch import nn
class GTLVQ(nn.Module): class GTLVQ(nn.Module):
@ -99,7 +98,7 @@ class GTLVQ(nn.Module):
self.cls = Prototypes1D( self.cls = Prototypes1D(
input_dim=feature_dim, input_dim=feature_dim,
prototypes_per_class=prototypes_per_class, prototypes_per_class=prototypes_per_class,
nclasses=num_classes, num_classes=num_classes,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=prototype_data, data=prototype_data,
) )

View File

@ -3,7 +3,6 @@
import warnings import warnings
import torch import torch
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
@ -53,13 +52,13 @@ class Prototypes1D(_Prototypes):
raise NameError("`input_dim` required if " raise NameError("`input_dim` required if "
"no `data` is provided.") "no `data` is provided.")
if prototype_distribution: if prototype_distribution:
kwargs_nclasses = sum(prototype_distribution) kwargs_num_classes = sum(prototype_distribution)
else: else:
if "nclasses" not in kwargs: if "num_classes" not in kwargs:
raise NameError("`prototype_distribution` required if " raise NameError("`prototype_distribution` required if "
"both `data` and `nclasses` are not " "both `data` and `num_classes` are not "
"provided.") "provided.")
kwargs_nclasses = kwargs.pop("nclasses") kwargs_num_classes = kwargs.pop("num_classes")
input_dim = kwargs.pop("input_dim") input_dim = kwargs.pop("input_dim")
if prototype_initializer in [ if prototype_initializer in [
"stratified_mean", "stratified_random" "stratified_mean", "stratified_random"
@ -68,18 +67,18 @@ class Prototypes1D(_Prototypes):
f"`prototype_initializer`: `{prototype_initializer}` " f"`prototype_initializer`: `{prototype_initializer}` "
"requires `data`, but `data` is not provided. " "requires `data`, but `data` is not provided. "
"Using randomly generated data instead.") "Using randomly generated data instead.")
x_train = torch.rand(kwargs_nclasses, input_dim) x_train = torch.rand(kwargs_num_classes, input_dim)
y_train = torch.arange(kwargs_nclasses) y_train = torch.arange(kwargs_num_classes)
if one_hot_labels: if one_hot_labels:
y_train = torch.eye(kwargs_nclasses)[y_train] y_train = torch.eye(kwargs_num_classes)[y_train]
data = [x_train, y_train] data = [x_train, y_train]
x_train, y_train = data x_train, y_train = data
x_train = torch.as_tensor(x_train).type(dtype) x_train = torch.as_tensor(x_train).type(dtype)
y_train = torch.as_tensor(y_train).type(torch.int) y_train = torch.as_tensor(y_train).type(torch.int)
nclasses = torch.unique(y_train, dim=-1).shape[-1] num_classes = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1: if num_classes == 1:
warnings.warn("Are you sure about having one class only?") warnings.warn("Are you sure about having one class only?")
if x_train.ndim != 2: if x_train.ndim != 2:
@ -105,19 +104,20 @@ class Prototypes1D(_Prototypes):
"not match data dimension " "not match data dimension "
f"`data[0].shape[1]`={x_train.shape[1]}") f"`data[0].shape[1]`={x_train.shape[1]}")
# Verify the number of classes if `nclasses` is provided # Verify the number of classes if `num_classes` is provided
if "nclasses" in kwargs: if "num_classes" in kwargs:
kwargs_nclasses = kwargs.pop("nclasses") kwargs_num_classes = kwargs.pop("num_classes")
if kwargs_nclasses != nclasses: if kwargs_num_classes != num_classes:
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does " raise ValueError(
"not match data labels " f"Provided `num_classes={kwargs_num_classes}` does "
"`torch.unique(data[1]).shape[0]`" "not match data labels "
f"={nclasses}") "`torch.unique(data[1]).shape[0]`"
f"={num_classes}")
super().__init__(**kwargs) super().__init__(**kwargs)
if not prototype_distribution: if not prototype_distribution:
prototype_distribution = [prototypes_per_class] * nclasses prototype_distribution = [prototypes_per_class] * num_classes
with torch.no_grad(): with torch.no_grad():
self.prototype_distribution = torch.tensor(prototype_distribution) self.prototype_distribution = torch.tensor(prototype_distribution)

View File

@ -4,7 +4,6 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from prototorch.modules import losses, prototypes from prototorch.modules import losses, prototypes
@ -18,20 +17,20 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_without_input_dim(self): def test_prototypes1d_init_without_input_dim(self):
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = prototypes.Prototypes1D(nclasses=2) _ = prototypes.Prototypes1D(num_classes=2)
def test_prototypes1d_init_without_nclasses(self): def test_prototypes1d_init_without_num_classes(self):
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = prototypes.Prototypes1D(input_dim=1) _ = prototypes.Prototypes1D(input_dim=1)
def test_prototypes1d_init_with_nclasses_1(self): def test_prototypes1d_init_with_num_classes_1(self):
with self.assertWarns(UserWarning): with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1) _ = prototypes.Prototypes1D(num_classes=1, input_dim=1)
def test_prototypes1d_init_without_pdist(self): def test_prototypes1d_init_without_pdist(self):
p1 = prototypes.Prototypes1D( p1 = prototypes.Prototypes1D(
input_dim=6, input_dim=6,
nclasses=2, num_classes=2,
prototypes_per_class=4, prototypes_per_class=4,
prototype_initializer="ones", prototype_initializer="ones",
) )
@ -60,7 +59,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertWarns(UserWarning): with self.assertWarns(UserWarning):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
input_dim=3, input_dim=3,
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=None, data=None,
@ -81,7 +80,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_without_inputdim_with_data(self): def test_prototypes1d_init_without_inputdim_with_data(self):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=[[[1.0], [0.0]], [1, 0]], data=[[[1.0], [0.0]], [1, 0]],
@ -89,7 +88,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_with_int_data(self): def test_prototypes1d_init_with_int_data(self):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]], data=[[[1], [0]], [1, 0]],
@ -98,7 +97,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_one_hot_without_data(self): def test_prototypes1d_init_one_hot_without_data(self):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
input_dim=1, input_dim=1,
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=None, data=None,
@ -112,7 +111,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
input_dim=1, input_dim=1,
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [[0, 1], [1, 0]]), data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
@ -126,7 +125,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
input_dim=1, input_dim=1,
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [0, 1]), data=([[0.0], [1.0]], [0, 1]),
@ -141,7 +140,7 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
input_dim=1, input_dim=1,
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=([[0.0], [1.0]], [[0], [1]]), data=([[0.0], [1.0]], [[0], [1]]),
@ -151,7 +150,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_init_with_int_dtype(self): def test_prototypes1d_init_with_int_dtype(self):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=[[[1], [0]], [1, 0]], data=[[[1], [0]], [1, 0]],
@ -161,7 +160,7 @@ class TestPrototypes(unittest.TestCase):
def test_prototypes1d_inputndim_with_data(self): def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(input_dim=1, _ = prototypes.Prototypes1D(input_dim=1,
nclasses=1, num_classes=1,
prototypes_per_class=1, prototypes_per_class=1,
data=[[1.0], [1]]) data=[[1.0], [1]])
@ -169,20 +168,20 @@ class TestPrototypes(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
input_dim=2, input_dim=2,
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=[[[1.0], [0.0]], [1, 0]], data=[[[1.0], [0.0]], [1, 0]],
) )
def test_prototypes1d_nclasses_with_data(self): def test_prototypes1d_num_classes_with_data(self):
"""Test ValueError raise if provided `nclasses` is not the same """Test ValueError raise if provided `num_classes` is not the same
as the one computed from the provided `data`. as the one computed from the provided `data`.
""" """
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D( _ = prototypes.Prototypes1D(
input_dim=1, input_dim=1,
nclasses=1, num_classes=1,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer="stratified_mean", prototype_initializer="stratified_mean",
data=[[[1.0], [2.0]], [1, 2]], data=[[[1.0], [2.0]], [1, 2]],
@ -220,7 +219,7 @@ class TestPrototypes(unittest.TestCase):
p1 = prototypes.Prototypes1D( p1 = prototypes.Prototypes1D(
input_dim=99, input_dim=99,
nclasses=2, num_classes=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer=my_initializer, prototype_initializer=my_initializer,
) )