Add prototype and loss modules
parent 8a96749716
commit 398d863232
0   prototorch/modules/__init__.py   Normal file
21  prototorch/modules/losses.py   Normal file
@@ -0,0 +1,21 @@
"""ProtoTorch losses."""

import torch

from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss


class GLVQLoss(torch.nn.Module):
    """GLVQ Loss."""
    def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin
        self.squashing = get_activation(squashing)
        self.beta = beta

    def forward(self, outputs, targets):
        distances, plabels = outputs
        mu = glvq_loss(distances, targets, plabels)
        batch_loss = self.squashing(mu + self.margin, beta=self.beta)
        return torch.sum(batch_loss, dim=0)
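A minimal usage sketch of the new loss, not part of the commit. In GLVQ (Sato & Yamada), the per-sample classifier function is mu = (d+ - d-) / (d+ + d-), where d+ and d- are the distances to the closest prototype with the correct and with an incorrect label; `glvq_loss` presumably computes this mu. The `(distances, plabels)` input convention follows `GLVQLoss.forward` above, and the sketch assumes the `'identity'` squashing named in the default signature is registered with `get_activation` and accepts the `beta` keyword, which `forward` always passes.

import torch

from prototorch.modules.losses import GLVQLoss

# Toy batch: distances from 4 samples to 3 prototypes (one per class).
distances = torch.rand(4, 3, requires_grad=True)
plabels = torch.tensor([0, 1, 2])     # class label of each prototype
targets = torch.tensor([0, 2, 1, 0])  # ground-truth label of each sample

criterion = GLVQLoss(margin=0.0, squashing='identity', beta=10)
loss = criterion((distances, plabels), targets)
loss.backward()  # gradients flow back through the distances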
57  prototorch/modules/prototypes.py   Normal file
@@ -0,0 +1,57 @@
"""ProtoTorch prototype modules."""

import torch

from prototorch.functions.initializers import get_initializer


class AddPrototypes1D(torch.nn.Module):
    def __init__(self,
                 prototypes_per_class=1,
                 prototype_distribution=None,
                 prototype_initializer='ones',
                 data=None,
                 **kwargs):

        if data is None:
            if 'input_dim' not in kwargs:
                raise NameError('`input_dim` required if '
                                'no `data` is provided.')
            if prototype_distribution is not None:
                nclasses = sum(prototype_distribution)
            else:
                if 'nclasses' not in kwargs:
                    raise NameError('`prototype_distribution` required if '
                                    'both `data` and `nclasses` are not '
                                    'provided.')
                nclasses = kwargs.pop('nclasses')
            input_dim = kwargs.pop('input_dim')
            # input_shape = (input_dim, )
            x_train = torch.rand(nclasses, input_dim)
            y_train = torch.arange(nclasses)

        else:
            x_train, y_train = data
            x_train = torch.as_tensor(x_train)
            y_train = torch.as_tensor(y_train)

        super().__init__(**kwargs)
        self.prototypes_per_class = prototypes_per_class
        with torch.no_grad():
            if not prototype_distribution:
                num_classes = torch.unique(y_train).shape[0]
                self.prototype_distribution = torch.tensor(
                    [self.prototypes_per_class] * num_classes)
            else:
                self.prototype_distribution = torch.tensor(
                    prototype_distribution)
            self.prototype_initializer = get_initializer(prototype_initializer)
            prototypes, prototype_labels = self.prototype_initializer(
                x_train,
                y_train,
                prototype_distribution=self.prototype_distribution)
        self.prototypes = torch.nn.Parameter(prototypes)
        self.prototype_labels = prototype_labels

    def forward(self):
        return self.prototypes, self.prototype_labels
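A hedged end-to-end sketch wiring both new modules together, not part of the commit: the distance computation via `torch.cdist` is a stand-in, since this commit ships no distance function, and the default `'ones'` initializer is assumed to be registered with `get_initializer` and to return `(prototypes, prototype_labels)` as the call in `__init__` above expects.

import torch

from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import AddPrototypes1D

# Toy 2-D data with 3 classes, two prototypes per class.
x_train = torch.rand(100, 2)
y_train = torch.randint(0, 3, (100,))

protolayer = AddPrototypes1D(prototypes_per_class=2,
                             data=(x_train, y_train))

prototypes, plabels = protolayer()
distances = torch.cdist(x_train, prototypes)  # stand-in Euclidean distances
criterion = GLVQLoss()
loss = criterion((distances, plabels), y_train)
loss.backward()  # gradients reach protolayer.prototypes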