Refactor prototypes module and begin documentation
This commit is contained in:
parent
cf7d7b5d9d
commit
a9d2855323
@ -7,16 +7,65 @@ import torch
|
|||||||
from prototorch.functions.initializers import get_initializer
|
from prototorch.functions.initializers import get_initializer
|
||||||
|
|
||||||
|
|
||||||
class Prototypes1D(torch.nn.Module):
|
class _Prototypes(torch.nn.Module):
|
||||||
|
"""Abstract prototypes class."""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def _check_prototype_distribution(self):
|
||||||
|
if 0 in self.prototype_distribution:
|
||||||
|
warnings.warn('Are you sure about the 0 in '
|
||||||
|
'`prototype_distribution`?')
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f'prototypes.shape: {tuple(self.prototypes.shape)}'
|
||||||
|
|
||||||
|
|
||||||
|
class Prototypes1D(_Prototypes):
|
||||||
|
r"""Create a learnable set of one-dimensional prototypes.
|
||||||
|
|
||||||
|
TODO Complete this doc-string
|
||||||
|
|
||||||
|
Kwargs:
|
||||||
|
prototypes_per_class: number of prototypes to use per class.
|
||||||
|
Default: ``1``
|
||||||
|
prototype_initializer: prototype initializer.
|
||||||
|
Default: ``'ones'``
|
||||||
|
prototype_distribution: prototype distribution vector.
|
||||||
|
Default: ``None``
|
||||||
|
input_dim: dimension of the incoming data.
|
||||||
|
nclasses: number of classes.
|
||||||
|
data: If set to ``None``, data-dependent initializers will be ignored.
|
||||||
|
Default: ``None``
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, H_{in})`
|
||||||
|
where :math:`H_{in} = \text{input_dim}`.
|
||||||
|
- Output: :math:`(N, H_{out})`
|
||||||
|
where :math:`H_{out} = \text{total_prototypes}`.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
prototypes: the learnable weights of the module of shape
|
||||||
|
:math:`(\text{total_prototypes}, \text{prototype_dimension})`.
|
||||||
|
prototype_labels: the non-learnable labels of the prototypes.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> p = Prototypes1D(input_dim=20, nclasses=10)
|
||||||
|
>>> input = torch.randn(128, 20)
|
||||||
|
>>> output = m(input)
|
||||||
|
>>> print(output.size())
|
||||||
|
torch.Size([20, 10])
|
||||||
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_distribution=None,
|
|
||||||
prototype_initializer='ones',
|
prototype_initializer='ones',
|
||||||
|
prototype_distribution=None,
|
||||||
data=None,
|
data=None,
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
# Accept PyTorch tensors, but convert to python lists before processing
|
# Convert torch tensors to python lists before processing
|
||||||
if torch.is_tensor(prototype_distribution):
|
if torch.is_tensor(prototype_distribution):
|
||||||
prototype_distribution = prototype_distribution.tolist()
|
prototype_distribution = prototype_distribution.tolist()
|
||||||
|
|
||||||
@ -76,6 +125,8 @@ class Prototypes1D(torch.nn.Module):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.prototype_distribution = torch.tensor(prototype_distribution)
|
self.prototype_distribution = torch.tensor(prototype_distribution)
|
||||||
|
|
||||||
|
self._check_prototype_distribution()
|
||||||
|
|
||||||
self.prototype_initializer = get_initializer(prototype_initializer)
|
self.prototype_initializer = get_initializer(prototype_initializer)
|
||||||
prototypes, prototype_labels = self.prototype_initializer(
|
prototypes, prototype_labels = self.prototype_initializer(
|
||||||
x_train,
|
x_train,
|
||||||
|
Loading…
Reference in New Issue
Block a user