Use 'num_' in all variable names
commit 73e6fe384e
parent aff7a385a3
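Note: the rename is user-facing wherever `Prototypes1D` is constructed, since the keyword argument `nclasses` becomes `num_classes`. A minimal before/after sketch (assuming iris-style `x_train`/`y_train` as in the first example below):

    from prototorch.modules.prototypes import Prototypes1D

    # Before: Prototypes1D(..., nclasses=3, ...)
    # After:
    proto_layer = Prototypes1D(
        input_dim=2,
        prototypes_per_class=3,
        num_classes=3,  # renamed from `nclasses`
        prototype_initializer="stratified_random",
        data=[x_train, y_train],
    )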
@@ -3,14 +3,13 @@
 import numpy as np
 import torch
 from matplotlib import pyplot as plt
-from sklearn.datasets import load_iris
-from sklearn.preprocessing import StandardScaler
-from torchinfo import summary
-
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import GLVQLoss
 from prototorch.modules.prototypes import Prototypes1D
+from sklearn.datasets import load_iris
+from sklearn.preprocessing import StandardScaler
+from torchinfo import summary
 
 # Prepare and preprocess the data
 scaler = StandardScaler()
@@ -28,7 +27,7 @@ class Model(torch.nn.Module):
         self.proto_layer = Prototypes1D(
             input_dim=2,
             prototypes_per_class=3,
-            nclasses=3,
+            num_classes=3,
             prototype_initializer="stratified_random",
             data=[x_train, y_train],
         )
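For orientation, a hedged sketch of how the imports kept above fit together at inference time (tensor shapes and variable names assumed, not part of this diff):

    import torch
    from prototorch.functions.competitions import wtac
    from prototorch.functions.distances import euclidean_distance

    x = torch.rand(10, 2)                            # 10 samples, 2 features
    protos = torch.rand(9, 2)                        # 3 prototypes per class
    plabels = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
    dists = euclidean_distance(x, protos)            # sample-prototype distances
    y_pred = wtac(dists, plabels)                    # winner-takes-all labels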
@@ -2,13 +2,12 @@
 
 import matplotlib.pyplot as plt
 import torch
-from torch.utils.data import DataLoader
-
 from prototorch.datasets.tecator import Tecator
 from prototorch.functions.distances import sed
 from prototorch.modules import Prototypes1D
 from prototorch.modules.losses import GLVQLoss
 from prototorch.utils.colors import get_legend_handles
+from torch.utils.data import DataLoader
 
 # Prepare the dataset and dataloader
 train_data = Tecator(root="./artifacts", train=True)
@@ -23,7 +22,7 @@ class Model(torch.nn.Module):
         self.p1 = Prototypes1D(
             input_dim=100,
             prototypes_per_class=2,
-            nclasses=2,
+            num_classes=2,
             prototype_initializer="stratified_random",
             data=[x, y],
         )
@@ -12,14 +12,13 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torchvision
-from torchvision import transforms
-
 from prototorch.functions.helper import calculate_prototype_accuracy
 from prototorch.modules.losses import GLVQLoss
 from prototorch.modules.models import GTLVQ
+from torchvision import transforms
 
 # Parameters and options
-n_epochs = 50
+num_epochs = 50
 batch_size_train = 64
 batch_size_test = 1000
 learning_rate = 0.1
@@ -141,7 +140,7 @@ optimizer = torch.optim.Adam(
 criterion = GLVQLoss(squashing="sigmoid_beta", beta=10)
 
 # Training loop
-for epoch in range(n_epochs):
+for epoch in range(num_epochs):
     for batch_idx, (x_train, y_train) in enumerate(train_loader):
         model.train()
         x_train, y_train = x_train.to(device), y_train.to(device)
@@ -161,7 +160,7 @@ for epoch in range(n_epochs):
         if batch_idx % log_interval == 0:
             acc = calculate_prototype_accuracy(distances, y_train, plabels)
             print(
-                f"Epoch: {epoch + 1:02d}/{n_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
+                f"Epoch: {epoch + 1:02d}/{num_epochs:02d} Epoch Progress: {100. * batch_idx / len(train_loader):02.02f} % Loss: {loss.item():02.02f} \
                 Train Acc: {acc.item():02.02f}")
 
 # Test
@@ -139,8 +139,8 @@ class ReasoningComponents(Components):
 
     def _initialize_reasonings(self, reasonings):
         if type(reasonings) == tuple:
-            nclasses, ncomps = reasonings
-            reasonings = ZeroReasoningsInitializer(nclasses, ncomps)
+            num_classes, ncomps = reasonings
+            reasonings = ZeroReasoningsInitializer(num_classes, ncomps)
 
         _reasonings = reasonings.generate()
         self.register_parameter("_reasonings", _reasonings)
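The tuple branch above means a caller can hand over a plain `(num_classes, num_components)` pair instead of an initializer object. Spelled out as a sketch (import path assumed; only the calls visible in the hunk are used):

    from prototorch.components.initializers import ZeroReasoningsInitializer  # path assumed

    num_classes, ncomps = 3, 10
    reasonings = ZeroReasoningsInitializer(num_classes, ncomps)
    _reasonings = reasonings.generate()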
@@ -4,7 +4,7 @@ import numpy as np
 import torch
 
 
-def make_spiral(n_samples=500, noise=0.3):
+def make_spiral(num_samples=500, noise=0.3):
     """Generates the Spiral Dataset.
 
     For use in Prototorch use `prototorch.datasets.Spiral` instead.
@@ -12,14 +12,14 @@ def make_spiral(n_samples=500, noise=0.3):
     def get_samples(n, delta_t):
         points = []
         for i in range(n):
-            r = i / n_samples * 5
+            r = i / num_samples * 5
             t = 1.75 * i / n * 2 * np.pi + delta_t
             x = r * np.sin(t) + np.random.rand(1) * noise
             y = r * np.cos(t) + np.random.rand(1) * noise
             points.append([x, y])
         return points
 
-    n = n_samples // 2
+    n = num_samples // 2
     positive = get_samples(n=n, delta_t=0)
     negative = get_samples(n=n, delta_t=np.pi)
     x = np.concatenate(
@@ -45,13 +45,13 @@ class Spiral(torch.utils.data.TensorDataset):
           - test size
         * - 2
           - 2
-          - n_samples
+          - num_samples
           - 0
           - 0
 
-    :param n_samples: number of random samples
+    :param num_samples: number of random samples
     :param noise: noise added to the spirals
     """
-    def __init__(self, n_samples: int = 500, noise: float = 0.3):
-        x, y = make_spiral(n_samples, noise)
+    def __init__(self, num_samples: int = 500, noise: float = 0.3):
+        x, y = make_spiral(num_samples, noise)
         super().__init__(torch.Tensor(x), torch.LongTensor(y))
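A usage sketch after the rename, following the docstring's pointer to `prototorch.datasets.Spiral` (exact import path assumed):

    import torch
    from prototorch.datasets import Spiral  # import path assumed

    train_data = Spiral(num_samples=600, noise=0.3)  # was n_samples=...
    x, y = train_data[0]                             # one (sample, label) pair
    loader = torch.utils.data.DataLoader(train_data, batch_size=50)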
@@ -5,12 +5,12 @@ import torch
 
 def stratified_min(distances, labels):
     clabels = torch.unique(labels, dim=0)
-    nclasses = clabels.size()[0]
-    if distances.size()[1] == nclasses:
+    num_classes = clabels.size()[0]
+    if distances.size()[1] == num_classes:
         # skip if only one prototype per class
         return distances
     batch_size = distances.size()[0]
-    winning_distances = torch.zeros(nclasses, batch_size)
+    winning_distances = torch.zeros(num_classes, batch_size)
     inf = torch.full_like(distances.T, fill_value=float("inf"))
     # distances_to_wpluses = torch.where(matcher, distances, inf)
     for i, cl in enumerate(clabels):
@@ -18,7 +18,7 @@ def stratified_min(distances, labels):
         matcher = torch.eq(labels.unsqueeze(dim=1), cl)
         if labels.ndim == 2:
             # if the labels are one-hot vectors
-            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
         cdists = torch.where(matcher, distances.T, inf).T
         winning_distances[i] = torch.min(cdists, dim=1,
                                          keepdim=True).values.squeeze()
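A small worked example of `stratified_min` after the rename (import path assumed; `plabels` holds one integer class label per column of `distances`):

    import torch
    from prototorch.functions.competitions import stratified_min  # path assumed

    # 2 samples x 4 prototypes, two prototypes per class
    distances = torch.tensor([[0.3, 0.9, 0.2, 0.8],
                              [0.5, 0.1, 0.7, 0.4]])
    plabels = torch.tensor([0, 0, 1, 1])
    wd = stratified_min(distances, plabels)
    # winning distance per class: class 0 -> [0.3, 0.1], class 1 -> [0.2, 0.4]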
@@ -15,59 +15,59 @@ def register_initializer(function):
 
 def labels_from(distribution, one_hot=True):
     """Takes a distribution tensor and returns a labels tensor."""
-    nclasses = distribution.shape[0]
-    llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
+    num_classes = distribution.shape[0]
+    llist = [[i] * n for i, n in zip(range(num_classes), distribution)]
     # labels = [l for cl in llist for l in cl]  # flatten the list of lists
     flat_llist = list(chain(*llist))  # flatten label list with itertools.chain
     plabels = torch.tensor(flat_llist, requires_grad=False)
     if one_hot:
-        return torch.eye(nclasses)[plabels]
+        return torch.eye(num_classes)[plabels]
     return plabels
 
 
 @register_initializer
 def ones(x_train, y_train, prototype_distribution, one_hot=True):
-    nprotos = torch.sum(prototype_distribution)
-    protos = torch.ones(nprotos, *x_train.shape[1:])
+    num_protos = torch.sum(prototype_distribution)
+    protos = torch.ones(num_protos, *x_train.shape[1:])
     plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
 def zeros(x_train, y_train, prototype_distribution, one_hot=True):
-    nprotos = torch.sum(prototype_distribution)
-    protos = torch.zeros(nprotos, *x_train.shape[1:])
+    num_protos = torch.sum(prototype_distribution)
+    protos = torch.zeros(num_protos, *x_train.shape[1:])
     plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
 def rand(x_train, y_train, prototype_distribution, one_hot=True):
-    nprotos = torch.sum(prototype_distribution)
-    protos = torch.rand(nprotos, *x_train.shape[1:])
+    num_protos = torch.sum(prototype_distribution)
+    protos = torch.rand(num_protos, *x_train.shape[1:])
     plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
 def randn(x_train, y_train, prototype_distribution, one_hot=True):
-    nprotos = torch.sum(prototype_distribution)
-    protos = torch.randn(nprotos, *x_train.shape[1:])
+    num_protos = torch.sum(prototype_distribution)
+    protos = torch.randn(num_protos, *x_train.shape[1:])
     plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
 def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
-    nprotos = torch.sum(prototype_distribution)
+    num_protos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
-    protos = torch.empty(nprotos, pdim)
+    protos = torch.empty(num_protos, pdim)
     plabels = labels_from(prototype_distribution, one_hot)
     for i, label in enumerate(plabels):
         matcher = torch.eq(label.unsqueeze(dim=0), y_train)
         if one_hot:
-            nclasses = y_train.size()[1]
-            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+            num_classes = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
         xl = x_train[matcher]
         mean_xl = torch.mean(xl, dim=0)
         protos[i] = mean_xl
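A quick worked example of `labels_from`, which the hunk above renames internally (`distribution` holds the per-class prototype counts):

    import torch
    from prototorch.functions.initializers import labels_from

    distribution = torch.tensor([1, 2])  # 1 prototype for class 0, 2 for class 1
    print(labels_from(distribution, one_hot=False))  # tensor([0, 1, 1])
    print(labels_from(distribution, one_hot=True))   # one-hot rows, shape (3, 2)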
@@ -81,15 +81,15 @@ def stratified_random(x_train,
                       prototype_distribution,
                       one_hot=True,
                       epsilon=1e-7):
-    nprotos = torch.sum(prototype_distribution)
+    num_protos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
-    protos = torch.empty(nprotos, pdim)
+    protos = torch.empty(num_protos, pdim)
     plabels = labels_from(prototype_distribution, one_hot)
     for i, label in enumerate(plabels):
         matcher = torch.eq(label.unsqueeze(dim=0), y_train)
         if one_hot:
-            nclasses = y_train.size()[1]
-            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+            num_classes = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
         xl = x_train[matcher]
         rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
         random_xl = xl[rand_index]
@@ -8,8 +8,8 @@ def _get_matcher(targets, labels):
     matcher = torch.eq(targets.unsqueeze(dim=1), labels)
     if labels.ndim == 2:
         # if the labels are one-hot vectors
-        nclasses = targets.size()[1]
-        matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        num_classes = targets.size()[1]
+        matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
     return matcher
 
 
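A worked check of the one-hot branch above, mirroring `_get_matcher` inline rather than importing the private helper:

    import torch

    targets = torch.eye(2)[[0, 1, 1]]  # one-hot targets for 3 samples
    labels = torch.eye(2)              # one one-hot label per prototype
    matcher = torch.eq(targets.unsqueeze(dim=1), labels)
    num_classes = targets.size()[1]
    matcher = torch.eq(torch.sum(matcher, dim=-1), num_classes)
    print(matcher)  # (3, 2) mask: True where sample class == prototype class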
@@ -1,11 +1,10 @@
 import torch
-from torch import nn
-
 from prototorch.functions.distances import (euclidean_distance_matrix,
                                             tangent_distance)
 from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
 from prototorch.functions.normalization import orthogonalization
 from prototorch.modules.prototypes import Prototypes1D
+from torch import nn
 
 
 class GTLVQ(nn.Module):
@@ -99,7 +98,7 @@ class GTLVQ(nn.Module):
         self.cls = Prototypes1D(
             input_dim=feature_dim,
             prototypes_per_class=prototypes_per_class,
-            nclasses=num_classes,
+            num_classes=num_classes,
             prototype_initializer="stratified_mean",
             data=prototype_data,
         )
@@ -3,7 +3,6 @@
 import warnings
 
 import torch
-
 from prototorch.functions.initializers import get_initializer
 
 
@@ -53,13 +52,13 @@ class Prototypes1D(_Prototypes):
                 raise NameError("`input_dim` required if "
                                 "no `data` is provided.")
             if prototype_distribution:
-                kwargs_nclasses = sum(prototype_distribution)
+                kwargs_num_classes = sum(prototype_distribution)
             else:
-                if "nclasses" not in kwargs:
+                if "num_classes" not in kwargs:
                     raise NameError("`prototype_distribution` required if "
-                                    "both `data` and `nclasses` are not "
+                                    "both `data` and `num_classes` are not "
                                     "provided.")
-                kwargs_nclasses = kwargs.pop("nclasses")
+                kwargs_num_classes = kwargs.pop("num_classes")
             input_dim = kwargs.pop("input_dim")
             if prototype_initializer in [
                     "stratified_mean", "stratified_random"
@@ -68,18 +67,18 @@ class Prototypes1D(_Prototypes):
                 f"`prototype_initializer`: `{prototype_initializer}` "
                 "requires `data`, but `data` is not provided. "
                 "Using randomly generated data instead.")
-            x_train = torch.rand(kwargs_nclasses, input_dim)
-            y_train = torch.arange(kwargs_nclasses)
+            x_train = torch.rand(kwargs_num_classes, input_dim)
+            y_train = torch.arange(kwargs_num_classes)
             if one_hot_labels:
-                y_train = torch.eye(kwargs_nclasses)[y_train]
+                y_train = torch.eye(kwargs_num_classes)[y_train]
             data = [x_train, y_train]
 
         x_train, y_train = data
         x_train = torch.as_tensor(x_train).type(dtype)
         y_train = torch.as_tensor(y_train).type(torch.int)
-        nclasses = torch.unique(y_train, dim=-1).shape[-1]
+        num_classes = torch.unique(y_train, dim=-1).shape[-1]
 
-        if nclasses == 1:
+        if num_classes == 1:
             warnings.warn("Are you sure about having one class only?")
 
         if x_train.ndim != 2:
@@ -105,19 +104,20 @@ class Prototypes1D(_Prototypes):
                              "not match data dimension "
                              f"`data[0].shape[1]`={x_train.shape[1]}")
 
-        # Verify the number of classes if `nclasses` is provided
-        if "nclasses" in kwargs:
-            kwargs_nclasses = kwargs.pop("nclasses")
-            if kwargs_nclasses != nclasses:
-                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
-                                 "not match data labels "
-                                 "`torch.unique(data[1]).shape[0]`"
-                                 f"={nclasses}")
+        # Verify the number of classes if `num_classes` is provided
+        if "num_classes" in kwargs:
+            kwargs_num_classes = kwargs.pop("num_classes")
+            if kwargs_num_classes != num_classes:
+                raise ValueError(
+                    f"Provided `num_classes={kwargs_num_classes}` does "
+                    "not match data labels "
+                    "`torch.unique(data[1]).shape[0]`"
+                    f"={num_classes}")
 
         super().__init__(**kwargs)
 
         if not prototype_distribution:
-            prototype_distribution = [prototypes_per_class] * nclasses
+            prototype_distribution = [prototypes_per_class] * num_classes
         with torch.no_grad():
             self.prototype_distribution = torch.tensor(prototype_distribution)
 
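The reworded error can be exercised directly; this mirrors `test_prototypes1d_num_classes_with_data` further down:

    from prototorch.modules.prototypes import Prototypes1D

    try:
        _ = Prototypes1D(
            input_dim=1,
            num_classes=1,  # the data below actually carries two classes
            prototypes_per_class=1,
            prototype_initializer="stratified_mean",
            data=[[[1.0], [2.0]], [1, 2]],
        )
    except ValueError as e:
        print(e)  # Provided `num_classes=1` does not match data labels ...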
@@ -4,7 +4,6 @@ import unittest
 
 import numpy as np
 import torch
-
 from prototorch.modules import losses, prototypes
 
 
@@ -18,20 +17,20 @@ class TestPrototypes(unittest.TestCase):
 
     def test_prototypes1d_init_without_input_dim(self):
         with self.assertRaises(NameError):
-            _ = prototypes.Prototypes1D(nclasses=2)
+            _ = prototypes.Prototypes1D(num_classes=2)
 
-    def test_prototypes1d_init_without_nclasses(self):
+    def test_prototypes1d_init_without_num_classes(self):
         with self.assertRaises(NameError):
             _ = prototypes.Prototypes1D(input_dim=1)
 
-    def test_prototypes1d_init_with_nclasses_1(self):
+    def test_prototypes1d_init_with_num_classes_1(self):
         with self.assertWarns(UserWarning):
-            _ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
+            _ = prototypes.Prototypes1D(num_classes=1, input_dim=1)
 
     def test_prototypes1d_init_without_pdist(self):
         p1 = prototypes.Prototypes1D(
             input_dim=6,
-            nclasses=2,
+            num_classes=2,
             prototypes_per_class=4,
             prototype_initializer="ones",
         )
@@ -60,7 +59,7 @@ class TestPrototypes(unittest.TestCase):
         with self.assertWarns(UserWarning):
             _ = prototypes.Prototypes1D(
                 input_dim=3,
-                nclasses=2,
+                num_classes=2,
                 prototypes_per_class=1,
                 prototype_initializer="stratified_mean",
                 data=None,
|
|||||||
|
|
||||||
def test_prototypes1d_init_without_inputdim_with_data(self):
|
def test_prototypes1d_init_without_inputdim_with_data(self):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
nclasses=2,
|
num_classes=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1.0], [0.0]], [1, 0]],
|
data=[[[1.0], [0.0]], [1, 0]],
|
||||||
@@ -89,7 +88,7 @@ class TestPrototypes(unittest.TestCase):
 
     def test_prototypes1d_init_with_int_data(self):
         _ = prototypes.Prototypes1D(
-            nclasses=2,
+            num_classes=2,
             prototypes_per_class=1,
             prototype_initializer="stratified_mean",
             data=[[[1], [0]], [1, 0]],
|
|||||||
def test_prototypes1d_init_one_hot_without_data(self):
|
def test_prototypes1d_init_one_hot_without_data(self):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
num_classes=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=None,
|
data=None,
|
||||||
@ -112,7 +111,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
num_classes=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
|
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
|
||||||
@ -126,7 +125,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
num_classes=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.0], [1.0]], [0, 1]),
|
data=([[0.0], [1.0]], [0, 1]),
|
||||||
@ -141,7 +140,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=2,
|
num_classes=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=([[0.0], [1.0]], [[0], [1]]),
|
data=([[0.0], [1.0]], [[0], [1]]),
|
||||||
@ -151,7 +150,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
def test_prototypes1d_init_with_int_dtype(self):
|
def test_prototypes1d_init_with_int_dtype(self):
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
nclasses=2,
|
num_classes=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer="stratified_mean",
|
prototype_initializer="stratified_mean",
|
||||||
data=[[[1], [0]], [1, 0]],
|
data=[[[1], [0]], [1, 0]],
|
||||||
@ -161,7 +160,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
def test_prototypes1d_inputndim_with_data(self):
|
def test_prototypes1d_inputndim_with_data(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
_ = prototypes.Prototypes1D(input_dim=1,
|
||||||
nclasses=1,
|
num_classes=1,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
data=[[1.0], [1]])
|
data=[[1.0], [1]])
|
||||||
|
|
||||||
@@ -169,20 +168,20 @@ class TestPrototypes(unittest.TestCase):
         with self.assertRaises(ValueError):
             _ = prototypes.Prototypes1D(
                 input_dim=2,
-                nclasses=2,
+                num_classes=2,
                 prototypes_per_class=1,
                 prototype_initializer="stratified_mean",
                 data=[[[1.0], [0.0]], [1, 0]],
             )
 
-    def test_prototypes1d_nclasses_with_data(self):
-        """Test ValueError raise if provided `nclasses` is not the same
+    def test_prototypes1d_num_classes_with_data(self):
+        """Test ValueError raise if provided `num_classes` is not the same
         as the one computed from the provided `data`.
         """
         with self.assertRaises(ValueError):
             _ = prototypes.Prototypes1D(
                 input_dim=1,
-                nclasses=1,
+                num_classes=1,
                 prototypes_per_class=1,
                 prototype_initializer="stratified_mean",
                 data=[[[1.0], [2.0]], [1, 2]],
@@ -220,7 +219,7 @@ class TestPrototypes(unittest.TestCase):
 
         p1 = prototypes.Prototypes1D(
             input_dim=99,
-            nclasses=2,
+            num_classes=2,
             prototypes_per_class=1,
             prototype_initializer=my_initializer,
         )
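The `prototype_initializer=my_initializer` above implies custom initializers are callables with the built-ins' signature; a sketch modeled on `rand` (the body is illustrative, not from this diff):

    import torch
    from prototorch.functions.initializers import labels_from

    def my_initializer(x_train, y_train, prototype_distribution, one_hot=True):
        """Custom initializer: same signature as the built-ins above."""
        num_protos = torch.sum(prototype_distribution)
        protos = torch.rand(num_protos, *x_train.shape[1:])
        plabels = labels_from(prototype_distribution, one_hot)
        return protos, plabels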