Add glvq model
This commit is contained in:
parent
3687abc3a7
commit
f6994dfd83
57
prototorch/models/glvq.py
Normal file
57
prototorch/models/glvq.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from prototorch.functions.competitions import wtac
|
||||||
|
from prototorch.functions.distances import euclidean_distance
|
||||||
|
from prototorch.functions.initializers import get_initializer
|
||||||
|
from prototorch.functions.losses import glvq_loss
|
||||||
|
from prototorch.modules.prototypes import Prototypes1D
|
||||||
|
|
||||||
|
|
||||||
|
class GLVQ(pl.LightningModule):
|
||||||
|
"""Generalized Learning Vector Quantization."""
|
||||||
|
def __init__(self, lr=1e-3, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.lr = lr
|
||||||
|
self.proto_layer = Prototypes1D(**kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototypes(self):
|
||||||
|
return self.proto_layer.prototypes.detach().numpy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototype_labels(self):
|
||||||
|
return self.proto_layer.prototype_labels.detach().numpy()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
protos = self.proto_layer.prototypes
|
||||||
|
dis = euclidean_distance(x, protos)
|
||||||
|
return dis
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def training_step(self, train_batch, batch_idx):
|
||||||
|
x, y = train_batch
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
dis = self(x)
|
||||||
|
plabels = self.proto_layer.prototype_labels
|
||||||
|
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
||||||
|
loss = mu.sum(dim=0)
|
||||||
|
self.log("train_loss", loss)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def predict(self, x):
|
||||||
|
with torch.no_grad():
|
||||||
|
d = self(x)
|
||||||
|
plabels = self.proto_layer.prototype_labels
|
||||||
|
y_pred = wtac(d, plabels)
|
||||||
|
return y_pred.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGLVQ(GLVQ):
|
||||||
|
"""GLVQ model that constrains the prototypes to the range [0, 1] by
|
||||||
|
clamping after updates.
|
||||||
|
"""
|
||||||
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
|
self.proto_layer.prototypes.data.clamp_(0., 1.)
|
Loading…
Reference in New Issue
Block a user