Add LVQ1 and LVQ2.1 Models.

This commit is contained in:
Alexander Engelsberger
2021-05-11 13:26:13 +02:00
parent 30ee287ecc
commit 3fa6378c4d
3 changed files with 85 additions and 3 deletions

View File

@@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from .abstract import AbstractPrototypeModel
from torch.optim.lr_scheduler import ExponentialLR
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
@@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel):
self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy()
self.loss = glvq_loss
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
@@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel):
x = x.view(x.size(0), -1) # flatten
dis = self(x)
plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels)
mu = self.loss(dis, y, prototype_labels=plabels)
batch_loss = self.transfer_function(mu,
beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
@@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel):
return y_pred.numpy()
class LVQ1(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq1_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class LVQ21(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq21_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class ImageGLVQ(GLVQ):
"""GLVQ for training on image data.