Use 'num_' in all variable names

This commit is contained in:
Alexander Engelsberger
2021-05-25 15:57:05 +02:00
parent aff7a385a3
commit 73e6fe384e
11 changed files with 84 additions and 89 deletions

View File

@@ -1,11 +1,10 @@
import torch
from torch import nn
from prototorch.functions.distances import (euclidean_distance_matrix,
tangent_distance)
from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
from prototorch.functions.normalization import orthogonalization
from prototorch.modules.prototypes import Prototypes1D
from torch import nn
class GTLVQ(nn.Module):
@@ -99,7 +98,7 @@ class GTLVQ(nn.Module):
self.cls = Prototypes1D(
input_dim=feature_dim,
prototypes_per_class=prototypes_per_class,
nclasses=num_classes,
num_classes=num_classes,
prototype_initializer="stratified_mean",
data=prototype_data,
)

View File

@@ -3,7 +3,6 @@
import warnings
import torch
from prototorch.functions.initializers import get_initializer
@@ -53,13 +52,13 @@ class Prototypes1D(_Prototypes):
raise NameError("`input_dim` required if "
"no `data` is provided.")
if prototype_distribution:
kwargs_nclasses = sum(prototype_distribution)
kwargs_num_classes = sum(prototype_distribution)
else:
if "nclasses" not in kwargs:
if "num_classes" not in kwargs:
raise NameError("`prototype_distribution` required if "
"both `data` and `nclasses` are not "
"both `data` and `num_classes` are not "
"provided.")
kwargs_nclasses = kwargs.pop("nclasses")
kwargs_num_classes = kwargs.pop("num_classes")
input_dim = kwargs.pop("input_dim")
if prototype_initializer in [
"stratified_mean", "stratified_random"
@@ -68,18 +67,18 @@ class Prototypes1D(_Prototypes):
f"`prototype_initializer`: `{prototype_initializer}` "
"requires `data`, but `data` is not provided. "
"Using randomly generated data instead.")
x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(kwargs_nclasses)
x_train = torch.rand(kwargs_num_classes, input_dim)
y_train = torch.arange(kwargs_num_classes)
if one_hot_labels:
y_train = torch.eye(kwargs_nclasses)[y_train]
y_train = torch.eye(kwargs_num_classes)[y_train]
data = [x_train, y_train]
x_train, y_train = data
x_train = torch.as_tensor(x_train).type(dtype)
y_train = torch.as_tensor(y_train).type(torch.int)
nclasses = torch.unique(y_train, dim=-1).shape[-1]
num_classes = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1:
if num_classes == 1:
warnings.warn("Are you sure about having one class only?")
if x_train.ndim != 2:
@@ -105,19 +104,20 @@ class Prototypes1D(_Prototypes):
"not match data dimension "
f"`data[0].shape[1]`={x_train.shape[1]}")
# Verify the number of classes if `nclasses` is provided
if "nclasses" in kwargs:
kwargs_nclasses = kwargs.pop("nclasses")
if kwargs_nclasses != nclasses:
raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
"not match data labels "
"`torch.unique(data[1]).shape[0]`"
f"={nclasses}")
# Verify the number of classes if `num_classes` is provided
if "num_classes" in kwargs:
kwargs_num_classes = kwargs.pop("num_classes")
if kwargs_num_classes != num_classes:
raise ValueError(
f"Provided `num_classes={kwargs_num_classes}` does "
"not match data labels "
"`torch.unique(data[1]).shape[0]`"
f"={num_classes}")
super().__init__(**kwargs)
if not prototype_distribution:
prototype_distribution = [prototypes_per_class] * nclasses
prototype_distribution = [prototypes_per_class] * num_classes
with torch.no_grad():
self.prototype_distribution = torch.tensor(prototype_distribution)