Add prototype and loss modules
This commit is contained in:
parent
8a96749716
commit
398d863232
0
prototorch/modules/__init__.py
Normal file
0
prototorch/modules/__init__.py
Normal file
21
prototorch/modules/losses.py
Normal file
21
prototorch/modules/losses.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""ProtoTorch losses."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.losses import glvq_loss
|
||||
|
||||
|
||||
class GLVQLoss(torch.nn.Module):
|
||||
"""GLVQ Loss."""
|
||||
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.margin = margin
|
||||
self.squashing = get_activation(squashing)
|
||||
self.beta = beta
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
57
prototorch/modules/prototypes.py
Normal file
57
prototorch/modules/prototypes.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""ProtoTorch prototype modules."""
|
||||
|
||||
import torch
|
||||
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
|
||||
|
||||
class AddPrototypes1D(torch.nn.Module):
|
||||
def __init__(self,
|
||||
prototypes_per_class=1,
|
||||
prototype_distribution=None,
|
||||
prototype_initializer='ones',
|
||||
data=None,
|
||||
**kwargs):
|
||||
|
||||
if data is None:
|
||||
if 'input_dim' not in kwargs:
|
||||
raise NameError('`input_dim` required if '
|
||||
'no `data` is provided.')
|
||||
if prototype_distribution is not None:
|
||||
nclasses = sum(prototype_distribution)
|
||||
else:
|
||||
if 'nclasses' not in kwargs:
|
||||
raise NameError('`prototype_distribution` required if '
|
||||
'both `data` and `nclasses` are not '
|
||||
'provided.')
|
||||
nclasses = kwargs.pop('nclasses')
|
||||
input_dim = kwargs.pop('input_dim')
|
||||
# input_shape = (input_dim, )
|
||||
x_train = torch.rand(nclasses, input_dim)
|
||||
y_train = torch.arange(nclasses)
|
||||
|
||||
else:
|
||||
x_train, y_train = data
|
||||
x_train = torch.as_tensor(x_train)
|
||||
y_train = torch.as_tensor(y_train)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self.prototypes_per_class = prototypes_per_class
|
||||
with torch.no_grad():
|
||||
if not prototype_distribution:
|
||||
num_classes = torch.unique(y_train).shape[0]
|
||||
self.prototype_distribution = torch.tensor(
|
||||
[self.prototypes_per_class] * num_classes)
|
||||
else:
|
||||
self.prototype_distribution = torch.tensor(
|
||||
prototype_distribution)
|
||||
self.prototype_initializer = get_initializer(prototype_initializer)
|
||||
prototypes, prototype_labels = self.prototype_initializer(
|
||||
x_train,
|
||||
y_train,
|
||||
prototype_distribution=self.prototype_distribution)
|
||||
self.prototypes = torch.nn.Parameter(prototypes)
|
||||
self.prototype_labels = prototype_labels
|
||||
|
||||
def forward(self):
|
||||
return self.prototypes, self.prototype_labels
|
Loading…
Reference in New Issue
Block a user