From 398d863232006103808f707036910f923f57774a Mon Sep 17 00:00:00 2001
From: blackfly
Date: Mon, 6 Apr 2020 16:36:28 +0200
Subject: [PATCH] Add prototype and loss modules

---
 prototorch/modules/__init__.py   |  0
 prototorch/modules/losses.py     | 21 ++++++++++++
 prototorch/modules/prototypes.py | 57 ++++++++++++++++++++++++++++++++
 3 files changed, 78 insertions(+)
 create mode 100644 prototorch/modules/__init__.py
 create mode 100644 prototorch/modules/losses.py
 create mode 100644 prototorch/modules/prototypes.py

diff --git a/prototorch/modules/__init__.py b/prototorch/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/prototorch/modules/losses.py b/prototorch/modules/losses.py
new file mode 100644
index 0000000..468c407
--- /dev/null
+++ b/prototorch/modules/losses.py
@@ -0,0 +1,21 @@
+"""ProtoTorch losses."""
+
+import torch
+
+from prototorch.functions.activations import get_activation
+from prototorch.functions.losses import glvq_loss
+
+
+class GLVQLoss(torch.nn.Module):
+    """GLVQ Loss."""
+    def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
+        super().__init__(**kwargs)
+        self.margin = margin
+        self.squashing = get_activation(squashing)
+        self.beta = beta
+
+    def forward(self, outputs, targets):
+        distances, plabels = outputs
+        mu = glvq_loss(distances, targets, plabels)
+        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
+        return torch.sum(batch_loss, dim=0)
diff --git a/prototorch/modules/prototypes.py b/prototorch/modules/prototypes.py
new file mode 100644
index 0000000..9b94d09
--- /dev/null
+++ b/prototorch/modules/prototypes.py
@@ -0,0 +1,57 @@
+"""ProtoTorch prototype modules."""
+
+import torch
+
+from prototorch.functions.initializers import get_initializer
+
+
+class AddPrototypes1D(torch.nn.Module):
+    def __init__(self,
+                 prototypes_per_class=1,
+                 prototype_distribution=None,
+                 prototype_initializer='ones',
+                 data=None,
+                 **kwargs):
+
+        if data is None:
+            if 'input_dim' not in kwargs:
+                raise NameError('`input_dim` required if '
+                                'no `data` is provided.')
+            if prototype_distribution is not None:
+                nclasses = sum(prototype_distribution)
+            else:
+                if 'nclasses' not in kwargs:
+                    raise NameError('`prototype_distribution` required if '
+                                    'both `data` and `nclasses` are not '
+                                    'provided.')
+                nclasses = kwargs.pop('nclasses')
+            input_dim = kwargs.pop('input_dim')
+            # input_shape = (input_dim, )
+            x_train = torch.rand(nclasses, input_dim)
+            y_train = torch.arange(nclasses)
+
+        else:
+            x_train, y_train = data
+            x_train = torch.as_tensor(x_train)
+            y_train = torch.as_tensor(y_train)
+
+        super().__init__(**kwargs)
+        self.prototypes_per_class = prototypes_per_class
+        with torch.no_grad():
+            if not prototype_distribution:
+                num_classes = torch.unique(y_train).shape[0]
+                self.prototype_distribution = torch.tensor(
+                    [self.prototypes_per_class] * num_classes)
+            else:
+                self.prototype_distribution = torch.tensor(
+                    prototype_distribution)
+            self.prototype_initializer = get_initializer(prototype_initializer)
+            prototypes, prototype_labels = self.prototype_initializer(
+                x_train,
+                y_train,
+                prototype_distribution=self.prototype_distribution)
+            self.prototypes = torch.nn.Parameter(prototypes)
+            self.prototype_labels = prototype_labels
+
+    def forward(self):
+        return self.prototypes, self.prototype_labels
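
Notes (not part of the patch):

A minimal usage sketch for the new prototype module. It assumes the
'ones' initializer resolved by get_initializer accepts the call
signature the patch itself uses (x_train, y_train, and a
prototype_distribution tensor); the toy data below is random and
purely illustrative.

    import torch
    from prototorch.modules.prototypes import AddPrototypes1D

    # Toy dataset: 100 samples with 3 features, labels from 2 classes.
    x_train = torch.rand(100, 3)
    y_train = torch.randint(0, 2, (100, ))

    # Two trainable prototypes per class, initialized from the data.
    proto_layer = AddPrototypes1D(prototypes_per_class=2,
                                  data=[x_train, y_train])

    # forward() takes no input; it returns the prototype Parameter
    # tensor and the fixed prototype labels.
    prototypes, prototype_labels = proto_layer()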
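And a sketch of how the new GLVQLoss consumes the (distances,
prototype_labels) pair. The patch ships no distance function, so the
squared Euclidean distances below are computed with plain torch.cdist
to keep the example self-contained; the default 'identity' squashing
is assumed to accept the beta keyword, since the module always passes
it.

    import torch
    from prototorch.modules.losses import GLVQLoss

    x = torch.rand(5, 3)                               # batch of 5 samples
    prototypes = torch.rand(4, 3, requires_grad=True)  # 4 prototypes
    plabels = torch.tensor([0, 0, 1, 1])               # prototype labels
    targets = torch.tensor([0, 1, 0, 1, 1])            # ground-truth labels

    # Squared Euclidean distances between samples and prototypes.
    distances = torch.cdist(x, prototypes)**2

    criterion = GLVQLoss(margin=0.0, squashing='identity', beta=10)
    loss = criterion([distances, plabels], targets)
    loss.backward()  # gradients flow back to the prototypes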