Refactor prototypes module and begin documentation
This commit is contained in:
parent
cf7d7b5d9d
commit
a9d2855323
@ -7,16 +7,65 @@ import torch
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
|
||||
|
||||
class Prototypes1D(torch.nn.Module):
|
||||
class _Prototypes(torch.nn.Module):
|
||||
"""Abstract prototypes class."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _check_prototype_distribution(self):
|
||||
if 0 in self.prototype_distribution:
|
||||
warnings.warn('Are you sure about the 0 in '
|
||||
'`prototype_distribution`?')
|
||||
|
||||
def extra_repr(self):
|
||||
return f'prototypes.shape: {tuple(self.prototypes.shape)}'
|
||||
|
||||
|
||||
class Prototypes1D(_Prototypes):
|
||||
r"""Create a learnable set of one-dimensional prototypes.
|
||||
|
||||
TODO Complete this doc-string
|
||||
|
||||
Kwargs:
|
||||
prototypes_per_class: number of prototypes to use per class.
|
||||
Default: ``1``
|
||||
prototype_initializer: prototype initializer.
|
||||
Default: ``'ones'``
|
||||
prototype_distribution: prototype distribution vector.
|
||||
Default: ``None``
|
||||
input_dim: dimension of the incoming data.
|
||||
nclasses: number of classes.
|
||||
data: If set to ``None``, data-dependent initializers will be ignored.
|
||||
Default: ``None``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, H_{in})`
|
||||
where :math:`H_{in} = \text{input_dim}`.
|
||||
- Output: :math:`(N, H_{out})`
|
||||
where :math:`H_{out} = \text{total_prototypes}`.
|
||||
|
||||
Attributes:
|
||||
prototypes: the learnable weights of the module of shape
|
||||
:math:`(\text{total_prototypes}, \text{prototype_dimension})`.
|
||||
prototype_labels: the non-learnable labels of the prototypes.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> p = Prototypes1D(input_dim=20, nclasses=10)
|
||||
>>> input = torch.randn(128, 20)
|
||||
>>> output = m(input)
|
||||
>>> print(output.size())
|
||||
torch.Size([20, 10])
|
||||
"""
|
||||
def __init__(self,
|
||||
prototypes_per_class=1,
|
||||
prototype_distribution=None,
|
||||
prototype_initializer='ones',
|
||||
prototype_distribution=None,
|
||||
data=None,
|
||||
dtype=torch.float32,
|
||||
**kwargs):
|
||||
|
||||
# Accept PyTorch tensors, but convert to python lists before processing
|
||||
# Convert torch tensors to python lists before processing
|
||||
if torch.is_tensor(prototype_distribution):
|
||||
prototype_distribution = prototype_distribution.tolist()
|
||||
|
||||
@ -76,6 +125,8 @@ class Prototypes1D(torch.nn.Module):
|
||||
with torch.no_grad():
|
||||
self.prototype_distribution = torch.tensor(prototype_distribution)
|
||||
|
||||
self._check_prototype_distribution()
|
||||
|
||||
self.prototype_initializer = get_initializer(prototype_initializer)
|
||||
prototypes, prototype_labels = self.prototype_initializer(
|
||||
x_train,
|
||||
|
Loading…
Reference in New Issue
Block a user