Add LVQ1 and LVQ2.1 Models.

This commit is contained in:
Alexander Engelsberger
2021-05-11 13:26:13 +02:00
parent 30ee287ecc
commit 3fa6378c4d
3 changed files with 85 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
from importlib.metadata import PackageNotFoundError, version
from .cbc import CBC
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ, LVQ1, LVQ21
from .neural_gas import NeuralGas
from .vis import *

View File

@@ -5,10 +5,12 @@ from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.losses import glvq_loss
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from .abstract import AbstractPrototypeModel
from torch.optim.lr_scheduler import ExponentialLR
class GLVQ(AbstractPrototypeModel):
"""Generalized Learning Vector Quantization."""
@@ -30,6 +32,8 @@ class GLVQ(AbstractPrototypeModel):
self.transfer_function = get_activation(self.hparams.transfer_function)
self.train_acc = torchmetrics.Accuracy()
self.loss = glvq_loss
@property
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
@@ -44,7 +48,7 @@ class GLVQ(AbstractPrototypeModel):
x = x.view(x.size(0), -1) # flatten
dis = self(x)
plabels = self.proto_layer.component_labels
mu = glvq_loss(dis, y, prototype_labels=plabels)
mu = self.loss(dis, y, prototype_labels=plabels)
batch_loss = self.transfer_function(mu,
beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
@@ -76,6 +80,42 @@ class GLVQ(AbstractPrototypeModel):
return y_pred.numpy()
class LVQ1(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq1_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class LVQ21(GLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = lvq21_loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr)
scheduler = ExponentialLR(optimizer,
gamma=0.99,
last_epoch=-1,
verbose=False)
sch = {
"scheduler": scheduler,
"interval": "step",
} # called after each training step
return [optimizer], [sch]
class ImageGLVQ(GLVQ):
"""GLVQ for training on image data.