Refactor prototypes module and begin documentation

This commit is contained in:
blackfly 2020-04-14 19:48:46 +02:00
parent cf7d7b5d9d
commit a9d2855323

View File

@ -7,16 +7,65 @@ import torch
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
class Prototypes1D(torch.nn.Module): class _Prototypes(torch.nn.Module):
"""Abstract prototypes class."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _check_prototype_distribution(self):
if 0 in self.prototype_distribution:
warnings.warn('Are you sure about the 0 in '
'`prototype_distribution`?')
def extra_repr(self):
return f'prototypes.shape: {tuple(self.prototypes.shape)}'
class Prototypes1D(_Prototypes):
r"""Create a learnable set of one-dimensional prototypes.
TODO Complete this doc-string
Kwargs:
prototypes_per_class: number of prototypes to use per class.
Default: ``1``
prototype_initializer: prototype initializer.
Default: ``'ones'``
prototype_distribution: prototype distribution vector.
Default: ``None``
input_dim: dimension of the incoming data.
nclasses: number of classes.
data: If set to ``None``, data-dependent initializers will be ignored.
Default: ``None``
Shape:
- Input: :math:`(N, H_{in})`
where :math:`H_{in} = \text{input_dim}`.
- Output: :math:`(N, H_{out})`
where :math:`H_{out} = \text{total_prototypes}`.
Attributes:
prototypes: the learnable weights of the module of shape
:math:`(\text{total_prototypes}, \text{prototype_dimension})`.
prototype_labels: the non-learnable labels of the prototypes.
Examples::
>>> p = Prototypes1D(input_dim=20, nclasses=10)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([20, 10])
"""
def __init__(self, def __init__(self,
prototypes_per_class=1, prototypes_per_class=1,
prototype_distribution=None,
prototype_initializer='ones', prototype_initializer='ones',
prototype_distribution=None,
data=None, data=None,
dtype=torch.float32, dtype=torch.float32,
**kwargs): **kwargs):
# Accept PyTorch tensors, but convert to python lists before processing # Convert torch tensors to python lists before processing
if torch.is_tensor(prototype_distribution): if torch.is_tensor(prototype_distribution):
prototype_distribution = prototype_distribution.tolist() prototype_distribution = prototype_distribution.tolist()
@ -76,6 +125,8 @@ class Prototypes1D(torch.nn.Module):
with torch.no_grad(): with torch.no_grad():
self.prototype_distribution = torch.tensor(prototype_distribution) self.prototype_distribution = torch.tensor(prototype_distribution)
self._check_prototype_distribution()
self.prototype_initializer = get_initializer(prototype_initializer) self.prototype_initializer = get_initializer(prototype_initializer)
prototypes, prototype_labels = self.prototype_initializer( prototypes, prototype_labels = self.prototype_initializer(
x_train, x_train,