diff --git a/prototorch/modules/prototypes.py b/prototorch/modules/prototypes.py
index 73a273b..191cec5 100644
--- a/prototorch/modules/prototypes.py
+++ b/prototorch/modules/prototypes.py
@@ -2,8 +2,11 @@
 
 import warnings
 
+import numpy as np
 import torch
 
+from prototorch.functions.competitions import wtac
+from prototorch.functions.distances import sed
 from prototorch.functions.initializers import get_initializer
 
 
@@ -12,14 +15,17 @@ class _Prototypes(torch.nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def _check_prototype_distribution(self):
+    def _validate_prototype_distribution(self):
         if 0 in self.prototype_distribution:
-            warnings.warn('Are you sure about the 0 in '
+            warnings.warn('Are you sure about the `0` in '
                           '`prototype_distribution`?')
 
     def extra_repr(self):
         return f'prototypes.shape: {tuple(self.prototypes.shape)}'
 
+    def forward(self):
+        return self.prototypes, self.prototype_labels
+
 
 class Prototypes1D(_Prototypes):
     r"""Create a learnable set of one-dimensional prototypes.
@@ -63,24 +69,26 @@ class Prototypes1D(_Prototypes):
                  prototype_distribution=None,
                  data=None,
                  dtype=torch.float32,
+                 one_hot_labels=False,
                  **kwargs):
 
-        # Convert torch tensors to python lists before processing
-        if torch.is_tensor(prototype_distribution):
-            prototype_distribution = prototype_distribution.tolist()
+        # Convert tensors to python lists before processing
+        if prototype_distribution is not None:
+            if not isinstance(prototype_distribution, list):
+                prototype_distribution = prototype_distribution.tolist()
 
         if data is None:
             if 'input_dim' not in kwargs:
                 raise NameError('`input_dim` required if '
                                 'no `data` is provided.')
             if prototype_distribution:
-                nclasses = sum(prototype_distribution)
+                kwargs_nclasses = sum(prototype_distribution)
             else:
                 if 'nclasses' not in kwargs:
                     raise NameError('`prototype_distribution` required if '
                                     'both `data` and `nclasses` are not '
                                     'provided.')
-                nclasses = kwargs.pop('nclasses')
+                kwargs_nclasses = kwargs.pop('nclasses')
             input_dim = kwargs.pop('input_dim')
             if prototype_initializer in [
                     'stratified_mean', 'stratified_random'
@@ -89,18 +97,35 @@ class Prototypes1D(_Prototypes):
                     f'`prototype_initializer`: `{prototype_initializer}` '
                     'requires `data`, but `data` is not provided. '
                     'Using randomly generated data instead.')
-            x_train = torch.rand(nclasses, input_dim)
-            y_train = torch.arange(nclasses)
+            x_train = torch.rand(kwargs_nclasses, input_dim)
+            y_train = torch.arange(kwargs_nclasses)
+            if one_hot_labels:
+                y_train = torch.eye(kwargs_nclasses)[y_train]
             data = [x_train, y_train]
 
         x_train, y_train = data
         x_train = torch.as_tensor(x_train).type(dtype)
-        y_train = torch.as_tensor(y_train).type(dtype)
-        nclasses = torch.unique(y_train).shape[0]
+        y_train = torch.as_tensor(y_train).type(torch.int)
+        nclasses = torch.unique(y_train, dim=-1).shape[-1]
+
+        if nclasses == 1:
+            warnings.warn('Are you sure about having one class only?')
 
         if x_train.ndim != 2:
             raise ValueError('`data[0].ndim != 2`.')
 
+        if y_train.ndim == 2:
+            if y_train.shape[1] == 1 and one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `True` '
+                                 'but target labels are not one-hot-encoded.')
+            if y_train.shape[1] != 1 and not one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `False` '
+                                 'but target labels in `data` '
+                                 'are one-hot-encoded.')
+        if y_train.ndim == 1 and one_hot_labels:
+            raise ValueError('`one_hot_labels` is set to `True` '
+                             'but target labels are not one-hot-encoded.')
+
         # Verify input dimension if `input_dim` is provided
         if 'input_dim' in kwargs:
             input_dim = kwargs.pop('input_dim')
@@ -125,17 +150,16 @@ class Prototypes1D(_Prototypes):
 
         with torch.no_grad():
             self.prototype_distribution = torch.tensor(prototype_distribution)
 
-        self._check_prototype_distribution()
+        self._validate_prototype_distribution()
         self.prototype_initializer = get_initializer(prototype_initializer)
         prototypes, prototype_labels = self.prototype_initializer(
             x_train,
             y_train,
-            prototype_distribution=self.prototype_distribution)
+            prototype_distribution=self.prototype_distribution,
+            one_hot=one_hot_labels,
+        )
 
         # Register module parameters
         self.prototypes = torch.nn.Parameter(prototypes)
         self.prototype_labels = prototype_labels
-
-    def forward(self):
-        return self.prototypes, self.prototype_labels
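A minimal usage sketch of the changed interface, assuming only what the hunks above show; the toy tensors, the `stratified_mean` choice, and the printed shape are illustrative assumptions, not part of the patch:

# Illustrative sketch, not part of the patch: exercising the new
# `one_hot_labels` flag and the `forward` method moved up to `_Prototypes`.
import torch

from prototorch.modules.prototypes import Prototypes1D

# Hypothetical toy data: 6 samples, 2 features, 3 classes, one-hot targets.
x_train = torch.rand(6, 2)
y_train = torch.eye(3)[torch.tensor([0, 0, 1, 1, 2, 2])]

protos = Prototypes1D(
    prototype_initializer='stratified_mean',
    prototype_distribution=[1, 1, 1],  # one prototype per class
    data=[x_train, y_train],
    one_hot_labels=True,  # must match the one-hot `y_train` above
)

# `forward` now returns the prototypes together with their labels, so any
# subclass of `_Prototypes` can simply be called.
prototypes, prototype_labels = protos()
print(prototypes.shape)  # torch.Size([3, 2])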