diff --git a/prototorch/modules/prototypes.py b/prototorch/modules/prototypes.py index 081a7b8..73a273b 100644 --- a/prototorch/modules/prototypes.py +++ b/prototorch/modules/prototypes.py @@ -7,16 +7,65 @@ import torch from prototorch.functions.initializers import get_initializer -class Prototypes1D(torch.nn.Module): +class _Prototypes(torch.nn.Module): + """Abstract prototypes class.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _check_prototype_distribution(self): + if 0 in self.prototype_distribution: + warnings.warn('Are you sure about the 0 in ' + '`prototype_distribution`?') + + def extra_repr(self): + return f'prototypes.shape: {tuple(self.prototypes.shape)}' + + +class Prototypes1D(_Prototypes): + r"""Create a learnable set of one-dimensional prototypes. + + TODO Complete this doc-string + + Kwargs: + prototypes_per_class: number of prototypes to use per class. + Default: ``1`` + prototype_initializer: prototype initializer. + Default: ``'ones'`` + prototype_distribution: prototype distribution vector. + Default: ``None`` + input_dim: dimension of the incoming data. + nclasses: number of classes. + data: If set to ``None``, data-dependent initializers will be ignored. + Default: ``None`` + + Shape: + - Input: :math:`(N, H_{in})` + where :math:`H_{in} = \text{input_dim}`. + - Output: :math:`(N, H_{out})` + where :math:`H_{out} = \text{total_prototypes}`. + + Attributes: + prototypes: the learnable weights of the module of shape + :math:`(\text{total_prototypes}, \text{prototype_dimension})`. + prototype_labels: the non-learnable labels of the prototypes. + + Examples:: + + >>> p = Prototypes1D(input_dim=20, nclasses=10) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([20, 10]) + """ def __init__(self, prototypes_per_class=1, - prototype_distribution=None, prototype_initializer='ones', + prototype_distribution=None, data=None, dtype=torch.float32, **kwargs): - # Accept PyTorch tensors, but convert to python lists before processing + # Convert torch tensors to python lists before processing if torch.is_tensor(prototype_distribution): prototype_distribution = prototype_distribution.tolist() @@ -76,6 +125,8 @@ class Prototypes1D(torch.nn.Module): with torch.no_grad(): self.prototype_distribution = torch.tensor(prototype_distribution) + self._check_prototype_distribution() + self.prototype_initializer = get_initializer(prototype_initializer) prototypes, prototype_labels = self.prototype_initializer( x_train,