Use 'num_' in all variable names
@@ -1,11 +1,10 @@
 import torch
+from torch import nn
 
 from prototorch.functions.distances import (euclidean_distance_matrix,
                                             tangent_distance)
 from prototorch.functions.helper import _check_shapes, _int_and_mixed_shape
 from prototorch.functions.normalization import orthogonalization
 from prototorch.modules.prototypes import Prototypes1D
-from torch import nn
-
 
 class GTLVQ(nn.Module):
@@ -99,7 +98,7 @@ class GTLVQ(nn.Module):
         self.cls = Prototypes1D(
             input_dim=feature_dim,
             prototypes_per_class=prototypes_per_class,
-            nclasses=num_classes,
+            num_classes=num_classes,
             prototype_initializer="stratified_mean",
             data=prototype_data,
         )
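For callers, the only visible change in this file is the keyword name in the Prototypes1D constructor. A minimal usage sketch of the updated call, with made-up stand-in values for the GTLVQ hyperparameters (feature_dim, num_classes, prototypes_per_class, prototype_data):

import torch

from prototorch.modules.prototypes import Prototypes1D

# Made-up stand-ins for the GTLVQ constructor arguments.
feature_dim = 4
num_classes = 3
prototypes_per_class = 2
x = torch.rand(30, feature_dim)         # toy training inputs
y = torch.arange(30) % num_classes      # toy labels covering every class
prototype_data = [x, y]

# After this commit the keyword is `num_classes`; `nclasses` is no longer used.
cls = Prototypes1D(
    input_dim=feature_dim,
    prototypes_per_class=prototypes_per_class,
    num_classes=num_classes,
    prototype_initializer="stratified_mean",
    data=prototype_data,
)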
@@ -3,7 +3,6 @@
 import warnings
 
 import torch
 
 from prototorch.functions.initializers import get_initializer
 
 
@@ -53,13 +52,13 @@ class Prototypes1D(_Prototypes):
             raise NameError("`input_dim` required if "
                             "no `data` is provided.")
         if prototype_distribution:
-            kwargs_nclasses = sum(prototype_distribution)
+            kwargs_num_classes = sum(prototype_distribution)
         else:
-            if "nclasses" not in kwargs:
+            if "num_classes" not in kwargs:
                 raise NameError("`prototype_distribution` required if "
-                                "both `data` and `nclasses` are not "
+                                "both `data` and `num_classes` are not "
                                 "provided.")
-            kwargs_nclasses = kwargs.pop("nclasses")
+            kwargs_num_classes = kwargs.pop("num_classes")
         input_dim = kwargs.pop("input_dim")
         if prototype_initializer in [
                 "stratified_mean", "stratified_random"
@@ -68,18 +67,18 @@ class Prototypes1D(_Prototypes):
                 f"`prototype_initializer`: `{prototype_initializer}` "
                 "requires `data`, but `data` is not provided. "
                 "Using randomly generated data instead.")
-            x_train = torch.rand(kwargs_nclasses, input_dim)
-            y_train = torch.arange(kwargs_nclasses)
+            x_train = torch.rand(kwargs_num_classes, input_dim)
+            y_train = torch.arange(kwargs_num_classes)
             if one_hot_labels:
-                y_train = torch.eye(kwargs_nclasses)[y_train]
+                y_train = torch.eye(kwargs_num_classes)[y_train]
             data = [x_train, y_train]
 
         x_train, y_train = data
         x_train = torch.as_tensor(x_train).type(dtype)
         y_train = torch.as_tensor(y_train).type(torch.int)
-        nclasses = torch.unique(y_train, dim=-1).shape[-1]
+        num_classes = torch.unique(y_train, dim=-1).shape[-1]
 
-        if nclasses == 1:
+        if num_classes == 1:
             warnings.warn("Are you sure about having one class only?")
 
         if x_train.ndim != 2:
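The class count and the optional one-hot labels in the hunk above rely only on plain torch operations. The same two steps in isolation, with a made-up label tensor standing in for data[1]:

import torch

# Made-up integer labels standing in for `data[1]`.
y_train = torch.tensor([0, 0, 1, 2, 2, 2])

# Number of distinct classes, as computed in the hunk above.
num_classes = torch.unique(y_train, dim=-1).shape[-1]   # 3

# One-hot encoding via an identity matrix, used when `one_hot_labels` is set.
y_one_hot = torch.eye(num_classes)[y_train]             # shape (6, 3)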
@@ -105,19 +104,20 @@ class Prototypes1D(_Prototypes):
                                  "not match data dimension "
                                  f"`data[0].shape[1]`={x_train.shape[1]}")
 
-        # Verify the number of classes if `nclasses` is provided
-        if "nclasses" in kwargs:
-            kwargs_nclasses = kwargs.pop("nclasses")
-            if kwargs_nclasses != nclasses:
-                raise ValueError(f"Provided `nclasses={kwargs_nclasses}` does "
-                                 "not match data labels "
-                                 "`torch.unique(data[1]).shape[0]`"
-                                 f"={nclasses}")
+        # Verify the number of classes if `num_classes` is provided
+        if "num_classes" in kwargs:
+            kwargs_num_classes = kwargs.pop("num_classes")
+            if kwargs_num_classes != num_classes:
+                raise ValueError(
+                    f"Provided `num_classes={kwargs_num_classes}` does "
+                    "not match data labels "
+                    "`torch.unique(data[1]).shape[0]`"
+                    f"={num_classes}")
 
         super().__init__(**kwargs)
 
         if not prototype_distribution:
-            prototype_distribution = [prototypes_per_class] * nclasses
+            prototype_distribution = [prototypes_per_class] * num_classes
         with torch.no_grad():
             self.prototype_distribution = torch.tensor(prototype_distribution)
 
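When no explicit prototype_distribution is passed, the last hunk falls back to an even split of prototypes_per_class prototypes for each of the num_classes classes. A quick sketch of that default, with made-up numbers:

import torch

# Made-up values for illustration.
prototypes_per_class = 2
num_classes = 3

# Default distribution built in the hunk above: the same count for every class.
prototype_distribution = torch.tensor([prototypes_per_class] * num_classes)
# tensor([2, 2, 2]) -> 6 prototypes in total (sum of the distribution)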