This commit is contained in:
Jensun Ravichandran
2021-04-14 19:20:08 +02:00
parent 6796ec494f
commit 98a8fc52fa
12 changed files with 333 additions and 56 deletions

View File

@@ -25,40 +25,9 @@ class _Prototypes(torch.nn.Module):
class Prototypes1D(_Prototypes):
r"""Create a learnable set of one-dimensional prototypes.
"""Create a learnable set of one-dimensional prototypes.
TODO Complete this doc-string
Kwargs:
prototypes_per_class: number of prototypes to use per class.
Default: ``1``
prototype_initializer: prototype initializer.
Default: ``'ones'``
prototype_distribution: prototype distribution vector.
Default: ``None``
input_dim: dimension of the incoming data.
nclasses: number of classes.
data: If set to ``None``, data-dependent initializers will be ignored.
Default: ``None``
Shape:
- Input: :math:`(N, H_{in})`
where :math:`H_{in} = \text{input_dim}`.
- Output: :math:`(N, H_{out})`
where :math:`H_{out} = \text{total_prototypes}`.
Attributes:
prototypes: the learnable weights of the module of shape
:math:`(\text{total_prototypes}, \text{prototype_dimension})`.
prototype_labels: the non-learnable labels of the prototypes.
Examples:
>>> p = Prototypes1D(input_dim=20, nclasses=10)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([20, 10])
TODO Complete this doc-string.
"""
def __init__(self,
prototypes_per_class=1,