Use 'num_' in all variable names
This commit is contained in:
parent
aff7a385a3
commit
73e6fe384e
@ -3,14 +3,13 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torchinfo import summary
|
||||
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from sklearn.datasets import load_iris
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from torchinfo import summary
|
||||
|
||||
# Prepare and preprocess the data
|
||||
scaler = StandardScaler()
|
||||
@ -28,7 +27,7 @@ class Model(torch.nn.Module):
|
||||
self.proto_layer = Prototypes1D(
|
||||
input_dim=2,
|
||||
prototypes_per_class=3,
|
||||
nclasses=3,
|
||||
num_classes=3,
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x_train, y_train],
|
||||
)
|
||||
|
@ -2,13 +2,12 @@
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from prototorch.datasets.tecator import Tecator
|
||||
from prototorch.functions.distances import sed
|
||||
from prototorch.modules import Prototypes1D
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.utils.colors import get_legend_handles
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# Prepare the dataset and dataloader
|
||||
train_data = Tecator(root="./artifacts", train=True)
|
||||
@ -23,7 +22,7 @@ class Model(torch.nn.Module):
|
||||
self.p1 = Prototypes1D(
|
||||
input_dim=100,
|
||||
prototypes_per_class=2,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototype_initializer="stratified_random",
|
||||
data=[x, y],
|
||||
)
|
||||
|
@ -12,14 +12,13 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
|
||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
||||
from prototorch.modules.losses import GLVQLoss
|
||||
from prototorch.modules.models import GTLVQ
|
||||
from torchvision import transforms
|
||||
|
||||
# Parameters and options
|
||||
n_epochs = 50
|
||||
num_epochs = 50
|
||||
batch_size_train = 64
|
||||
batch_size_test = 1000
|
||||
learning_rate = 0.1
|
||||
@ -141,7 +140,7 @@ optimizer = torch.optim.Adam(
|
||||
criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(n_epochs):
|
||||
for epoch in range(num_epochs):
|
||||
for batch_idx, (x_train, y_train) in enumerate(train_loader):
|
||||
model.train()
|
||||
x_train, y_train = x_train.to(device), y_train.to(device)
|
||||
@ -161,7 +160,7 @@ for epoch in range(n_epochs):
|
||||
if batch_idx % log_interval == 0:
|
||||
acc = calculate_prototype_accuracy(distances, y_train, plabels)
|
||||
print(
|
||||
f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
||||
f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
|
||||
Train Acc: {acc.item():02.02f}")
|
||||
|
||||
# Test
|
||||
|
@ -139,8 +139,8 @@ class ReasoningComponents(Components):
|
||||
|
||||
def _initialize_reasonings(self, reasonings):
|
||||
if type(reasonings) == tuple:
|
||||
nclasses, ncomps = reasonings
|
||||
reasonings = ZeroReasoningsInitializer(nclasses, ncomps)
|
||||
num_classes, ncomps = reasonings
|
||||
reasonings = ZeroReasoningsInitializer(num_classes, ncomps)
|
||||
|
||||
_reasonings = reasonings.generate()
|
||||
self.register_parameter("_reasonings", _reasonings)
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def make_spiral(n_samples=500, noise=0.3):
|
||||
def make_spiral(num_samples=500, noise=0.3):
|
||||
"""Generates the Spiral Dataset.
|
||||
|
||||
For use in Prototorch use `prototorch.datasets.Spiral` instead.
|
||||
@ -12,14 +12,14 @@ def make_spiral(n_samples=500, noise=0.3):
|
||||
def get_samples(n, delta_t):
|
||||
points = []
|
||||
for i in range(n):
|
||||
r = i / n_samples * 5
|
||||
r = i / num_samples * 5
|
||||
t = 1.75 * i / n * 2 * np.pi + delta_t
|
||||
x = r * np.sin(t) + np.random.rand(1) * noise
|
||||
y = r * np.cos(t) + np.random.rand(1) * noise
|
||||
points.append([x, y])
|
||||
return points
|
||||
|
||||
n = n_samples // 2
|
||||
n = num_samples // 2
|
||||
positive = get_samples(n=n, delta_t=0)
|
||||
negative = get_samples(n=n, delta_t=np.pi)
|
||||
x = np.concatenate(
|
||||
@ -45,13 +45,13 @@ class Spiral(torch.utils.data.TensorDataset):
|
||||
- test size
|
||||
* - 2
|
||||
- 2
|
||||
- n_samples
|
||||
- num_samples
|
||||
- 0
|
||||
- 0
|
||||
|
||||
:param n_samples: number of random samples
|
||||
:param num_samples: number of random samples
|
||||
:param noise: noise added to the spirals
|
||||
"""
|
||||
def __init__(self, n_samples: int = 500, noise: float = 0.3):
|
||||
x, y = make_spiral(n_samples, noise)
|
||||
def __init__(self, num_samples: int = 500, noise: float = 0.3):
|
||||
x, y = make_spiral(num_samples, noise)
|
||||
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
||||
|
@ -5,12 +5,12 @@ import torch
|
||||
|
||||
def stratified_min(distances, labels):
|
||||
clabels = torch.unique(labels, dim=0)
|
||||
nclasses = clabels.size()[0]
|
||||
if distances.size()[1] == nclasses:
|
||||
num_classes = clabels.size()[0]
|
||||
if distances.size()[1] == num_classes:
|
||||
# skip if only one prototype per class
|
||||
return distances
|
||||
batch_size = distances.size()[0]
|
||||
winning_distances = torch.zeros(nclasses, batch_size)
|
||||
winning_distances = torch.zeros(num_classes, batch_size)
|
||||
inf = torch.full_like(distances.T, fill_value=float("inf"))
|
||||
# distances_to_wpluses = torch.where(matcher, distances, inf)
|
||||
for i, cl in enumerate(clabels):
|
||||
@ -18,7 +18,7 @@ def stratified_min(distances, labels):
|
||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||
if labels.ndim == 2:
|
||||
# if the labels are one-hot vectors
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
cdists = torch.where(matcher, distances.T, inf).T
|
||||
winning_distances[i] = torch.min(cdists, dim=1,
|
||||
keepdim=True).values.squeeze()
|
||||
|
@ -15,59 +15,59 @@ def register_initializer(function):
|
||||
|
||||
def labels_from(distribution, one_hot=True):
|
||||
"""Takes a distribution tensor and returns a labels tensor."""
|
||||
nclasses = distribution.shape[0]
|
||||
llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
|
||||
num_classes = distribution.shape[0]
|
||||
llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
|
||||
# labels = [l for cl in llist for l in cl] # flatten the list of lists
|
||||
flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
|
||||
plabels = torch.tensor(flat_llist, requires_grad=False)
|
||||
if one_hot:
|
||||
return torch.eye(nclasses)[plabels]
|
||||
return torch.eye(num_classes)[plabels]
|
||||
return plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def ones(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.ones(nprotos, *x_train.shape[1:])
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.ones(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def zeros(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.zeros(nprotos, *x_train.shape[1:])
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.zeros(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def rand(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.rand(nprotos, *x_train.shape[1:])
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.rand(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def randn(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
protos = torch.randn(nprotos, *x_train.shape[1:])
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
protos = torch.randn(num_protos, *x_train.shape[1:])
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
return protos, plabels
|
||||
|
||||
|
||||
@register_initializer
|
||||
def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
pdim = x_train.shape[1]
|
||||
protos = torch.empty(nprotos, pdim)
|
||||
protos = torch.empty(num_protos, pdim)
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
for i, label in enumerate(plabels):
|
||||
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
||||
if one_hot:
|
||||
nclasses = y_train.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||
num_classes = y_train.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
xl = x_train[matcher]
|
||||
mean_xl = torch.mean(xl, dim=0)
|
||||
protos[i] = mean_xl
|
||||
@ -81,15 +81,15 @@ def stratified_random(x_train,
|
||||
prototype_distribution,
|
||||
one_hot=True,
|
||||
epsilon=1e-7):
|
||||
nprotos = torch.sum(prototype_distribution)
|
||||
num_protos = torch.sum(prototype_distribution)
|
||||
pdim = x_train.shape[1]
|
||||
protos = torch.empty(nprotos, pdim)
|
||||
protos = torch.empty(num_protos, pdim)
|
||||
plabels = labels_from(prototype_distribution, one_hot)
|
||||
for i, label in enumerate(plabels):
|
||||
matcher = torch.eq(label.unsqueeze(dim=0), y_train)
|
||||
if one_hot:
|
||||
nclasses = y_train.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||
num_classes = y_train.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
xl = x_train[matcher]
|
||||
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
|
||||
random_xl = xl[rand_index]
|
||||
|
@ -8,8 +8,8 @@ def _get_matcher(targets, labels):
|
||||
matcher = torch.eq(targets.unsqueeze(dim=1), labels)
|
||||
if labels.ndim == 2:
|
||||
# if the labels are one-hot vectors
|
||||
nclasses = targets.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||
num_classes = targets.size()[1]
|
||||
matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
|
||||
return matcher
|
||||
|
||||
|
||||
|
@ -1,11 +1,10 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from prototorch.functions.distances import (euclidean_distance_matrix,
|
||||
tangent_distance)
|
||||
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
|
||||
from prototorch.functions.normalization import orthogonalization
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
from torch import nn
|
||||
|
||||
|
||||
class GTLVQ(nn.Module):
|
||||
@ -99,7 +98,7 @@ class GTLVQ(nn.Module):
|
||||
self.cls = Prototypes1D(
|
||||
input_dim=feature_dim,
|
||||
prototypes_per_class=prototypes_per_class,
|
||||
nclasses=num_classes,
|
||||
num_classes=num_classes,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=prototype_data,
|
||||
)
|
||||
|
@ -3,7 +3,6 @@
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
|
||||
|
||||
@ -53,13 +52,13 @@ class Prototypes1D(_Prototypes):
|
||||
raise NameError("`input_dim` required if "
|
||||
"no `data` is provided.")
|
||||
if prototype_distribution:
|
||||
kwargs_nclasses = sum(prototype_distribution)
|
||||
kwargs_num_classes = sum(prototype_distribution)
|
||||
else:
|
||||
if "nclasses" not in kwargs:
|
||||
if "num_classes" not in kwargs:
|
||||
raise NameError("`prototype_distribution` required if "
|
||||
"both `data` and `nclasses` are not "
|
||||
"both `data` and `num_classes` are not "
|
||||
"provided.")
|
||||
kwargs_nclasses = kwargs.pop("nclasses")
|
||||
kwargs_num_classes = kwargs.pop("num_classes")
|
||||
input_dim = kwargs.pop("input_dim")
|
||||
if prototype_initializer in [
|
||||
"stratified_mean", "stratified_random"
|
||||
@ -68,18 +67,18 @@ class Prototypes1D(_Prototypes):
|
||||
f"`prototype_initializer`: `{prototype_initializer}` "
|
||||
"requires `data`, but `data` is not provided. "
|
||||
"Using randomly generated data instead.")
|
||||
x_train = torch.rand(kwargs_nclasses, input_dim)
|
||||
y_train = torch.arange(kwargs_nclasses)
|
||||
x_train = torch.rand(kwargs_num_classes, input_dim)
|
||||
y_train = torch.arange(kwargs_num_classes)
|
||||
if one_hot_labels:
|
||||
y_train = torch.eye(kwargs_nclasses)[y_train]
|
||||
y_train = torch.eye(kwargs_num_classes)[y_train]
|
||||
data = [x_train, y_train]
|
||||
|
||||
x_train, y_train = data
|
||||
x_train = torch.as_tensor(x_train).type(dtype)
|
||||
y_train = torch.as_tensor(y_train).type(torch.int)
|
||||
nclasses = torch.unique(y_train, dim=-1).shape[-1]
|
||||
num_classes = torch.unique(y_train, dim=-1).shape[-1]
|
||||
|
||||
if nclasses == 1:
|
||||
if num_classes == 1:
|
||||
warnings.warn("Are you sure about having one class only?")
|
||||
|
||||
if x_train.ndim != 2:
|
||||
@ -105,19 +104,20 @@ class Prototypes1D(_Prototypes):
|
||||
"not match data dimension "
|
||||
f"`data[0].shape[1]`={x_train.shape[1]}")
|
||||
|
||||
# Verify the number of classes if `nclasses` is provided
|
||||
if "nclasses" in kwargs:
|
||||
kwargs_nclasses = kwargs.pop("nclasses")
|
||||
if kwargs_nclasses != nclasses:
|
||||
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
|
||||
# Verify the number of classes if `num_classes` is provided
|
||||
if "num_classes" in kwargs:
|
||||
kwargs_num_classes = kwargs.pop("num_classes")
|
||||
if kwargs_num_classes != num_classes:
|
||||
raise ValueError(
|
||||
f"Provided `num_classes={kwargs_num_classes}` does "
|
||||
"not match data labels "
|
||||
"`torch.unique(data[1]).shape[0]`"
|
||||
f"={nclasses}")
|
||||
f"={num_classes}")
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if not prototype_distribution:
|
||||
prototype_distribution = [prototypes_per_class] * nclasses
|
||||
prototype_distribution = [prototypes_per_class] * num_classes
|
||||
with torch.no_grad():
|
||||
self.prototype_distribution = torch.tensor(prototype_distribution)
|
||||
|
||||
|
@ -4,7 +4,6 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from prototorch.modules import losses, prototypes
|
||||
|
||||
|
||||
@ -18,20 +17,20 @@ class TestPrototypes(unittest.TestCase):
|
||||
|
||||
def test_prototypes1d_init_without_input_dim(self):
|
||||
with self.assertRaises(NameError):
|
||||
_ = prototypes.Prototypes1D(nclasses=2)
|
||||
_ = prototypes.Prototypes1D(num_classes=2)
|
||||
|
||||
def test_prototypes1d_init_without_nclasses(self):
|
||||
def test_prototypes1d_init_without_num_classes(self):
|
||||
with self.assertRaises(NameError):
|
||||
_ = prototypes.Prototypes1D(input_dim=1)
|
||||
|
||||
def test_prototypes1d_init_with_nclasses_1(self):
|
||||
def test_prototypes1d_init_with_num_classes_1(self):
|
||||
with self.assertWarns(UserWarning):
|
||||
_ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
|
||||
_ = prototypes.Prototypes1D(num_classes=1, input_dim=1)
|
||||
|
||||
def test_prototypes1d_init_without_pdist(self):
|
||||
p1 = prototypes.Prototypes1D(
|
||||
input_dim=6,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=4,
|
||||
prototype_initializer="ones",
|
||||
)
|
||||
@ -60,7 +59,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
with self.assertWarns(UserWarning):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=3,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=None,
|
||||
@ -81,7 +80,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
|
||||
def test_prototypes1d_init_without_inputdim_with_data(self):
|
||||
_ = prototypes.Prototypes1D(
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=[[[1.0], [0.0]], [1, 0]],
|
||||
@ -89,7 +88,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
|
||||
def test_prototypes1d_init_with_int_data(self):
|
||||
_ = prototypes.Prototypes1D(
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=[[[1], [0]], [1, 0]],
|
||||
@ -98,7 +97,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
def test_prototypes1d_init_one_hot_without_data(self):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=1,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=None,
|
||||
@ -112,7 +111,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=1,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
|
||||
@ -126,7 +125,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=1,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=([[0.0], [1.0]], [0, 1]),
|
||||
@ -141,7 +140,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=1,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=([[0.0], [1.0]], [[0], [1]]),
|
||||
@ -151,7 +150,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
def test_prototypes1d_init_with_int_dtype(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=[[[1], [0]], [1, 0]],
|
||||
@ -161,7 +160,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
def test_prototypes1d_inputndim_with_data(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(input_dim=1,
|
||||
nclasses=1,
|
||||
num_classes=1,
|
||||
prototypes_per_class=1,
|
||||
data=[[1.0], [1]])
|
||||
|
||||
@ -169,20 +168,20 @@ class TestPrototypes(unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=2,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=[[[1.0], [0.0]], [1, 0]],
|
||||
)
|
||||
|
||||
def test_prototypes1d_nclasses_with_data(self):
|
||||
"""Test ValueError raise if provided `nclasses` is not the same
|
||||
def test_prototypes1d_num_classes_with_data(self):
|
||||
"""Test ValueError raise if provided `num_classes` is not the same
|
||||
as the one computed from the provided `data`.
|
||||
"""
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=1,
|
||||
nclasses=1,
|
||||
num_classes=1,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer="stratified_mean",
|
||||
data=[[[1.0], [2.0]], [1, 2]],
|
||||
@ -220,7 +219,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
|
||||
p1 = prototypes.Prototypes1D(
|
||||
input_dim=99,
|
||||
nclasses=2,
|
||||
num_classes=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer=my_initializer,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user