Add prototype and loss modules

This commit is contained in:
blackfly 2020-04-06 16:36:28 +02:00
parent 8a96749716
commit 398d863232
3 changed files with 78 additions and 0 deletions

View File

View File

@ -0,0 +1,21 @@
"""ProtoTorch losses."""
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss
class GLVQLoss(torch.nn.Module):
"""GLVQ Loss."""
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
self.beta = beta
def forward(self, outputs, targets):
distances, plabels = outputs
mu = glvq_loss(distances, targets, plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0)

View File

@ -0,0 +1,57 @@
"""ProtoTorch prototype modules."""
import torch
from prototorch.functions.initializers import get_initializer
class AddPrototypes1D(torch.nn.Module):
def __init__(self,
prototypes_per_class=1,
prototype_distribution=None,
prototype_initializer='ones',
data=None,
**kwargs):
if data is None:
if 'input_dim' not in kwargs:
raise NameError('`input_dim` required if '
'no `data` is provided.')
if prototype_distribution is not None:
nclasses = sum(prototype_distribution)
else:
if 'nclasses' not in kwargs:
raise NameError('`prototype_distribution` required if '
'both `data` and `nclasses` are not '
'provided.')
nclasses = kwargs.pop('nclasses')
input_dim = kwargs.pop('input_dim')
# input_shape = (input_dim, )
x_train = torch.rand(nclasses, input_dim)
y_train = torch.arange(nclasses)
else:
x_train, y_train = data
x_train = torch.as_tensor(x_train)
y_train = torch.as_tensor(y_train)
super().__init__(**kwargs)
self.prototypes_per_class = prototypes_per_class
with torch.no_grad():
if not prototype_distribution:
num_classes = torch.unique(y_train).shape[0]
self.prototype_distribution = torch.tensor(
[self.prototypes_per_class] * num_classes)
else:
self.prototype_distribution = torch.tensor(
prototype_distribution)
self.prototype_initializer = get_initializer(prototype_initializer)
prototypes, prototype_labels = self.prototype_initializer(
x_train,
y_train,
prototype_distribution=self.prototype_distribution)
self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = prototype_labels
def forward(self):
return self.prototypes, self.prototype_labels