prototorch_models/prototorch/models/glvq.py

58 lines
1.8 KiB
Python
Raw Normal View History

2021-04-21 12:51:34 +00:00
import pytorch_lightning as pl
import torch
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.initializers import get_initializer
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
class GLVQ(pl.LightningModule):
    """Generalized Learning Vector Quantization.

    Wraps a ``Prototypes1D`` layer and trains it with ``glvq_loss`` on the
    distances returned by ``euclidean_distance``.
    """

    def __init__(self, lr=1e-3, **kwargs):
        """Build the model.

        Args:
            lr: Learning rate handed to the Adam optimizer in
                ``configure_optimizers``.
            **kwargs: Forwarded unchanged to ``Prototypes1D``.
        """
        super().__init__()
        self.lr = lr
        self.proto_layer = Prototypes1D(**kwargs)

    @property
    def prototypes(self):
        """Return a detached NumPy copy of the prototype tensor."""
        # `.cpu()` is a no-op for CPU tensors but is required before
        # `.numpy()` when the module has been moved to an accelerator.
        return self.proto_layer.prototypes.detach().cpu().numpy()

    @property
    def prototype_labels(self):
        """Return a detached NumPy copy of the prototype labels."""
        return self.proto_layer.prototype_labels.detach().cpu().numpy()

    def forward(self, x):
        """Return ``euclidean_distance(x, prototypes)`` for the input ``x``."""
        protos = self.proto_layer.prototypes
        dis = euclidean_distance(x, protos)
        return dis

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            train_batch: A pair ``(x, y)`` of inputs and targets.
            batch_idx: Index of the batch (unused here).

        Returns:
            The output of ``glvq_loss`` summed over dimension 0.
        """
        x, y = train_batch
        # Flatten everything after the batch dimension before measuring
        # distances to the prototypes.
        x = x.view(x.size(0), -1)
        dis = self(x)
        plabels = self.proto_layer.prototype_labels
        mu = glvq_loss(dis, y, prototype_labels=plabels)
        loss = mu.sum(dim=0)
        self.log("train_loss", loss)
        return loss

    def predict(self, x):
        """Return predicted labels for ``x`` as a NumPy array.

        Distances are computed without gradient tracking and resolved to
        labels by ``wtac`` using the prototype labels.
        NOTE(review): unlike ``training_step``, ``x`` is not flattened here;
        callers presumably pass already-flattened inputs - confirm.
        """
        with torch.no_grad():
            d = self(x)
            plabels = self.proto_layer.prototype_labels
            y_pred = wtac(d, plabels)
        # Move to host memory first: `.numpy()` raises on non-CPU tensors.
        return y_pred.cpu().numpy()
class ImageGLVQ(GLVQ):
    """GLVQ model that constrains the prototypes to the range [0, 1] by
    clamping after updates.
    """

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Clamp the prototypes in place to [0, 1] after each training batch.

        ``dataloader_idx`` is optional so the hook can be invoked both with
        and without that trailing argument; none of the arguments are used.
        """
        # Operate on `.data` so the in-place clamp is not tracked by autograd.
        self.proto_layer.prototypes.data.clamp_(0., 1.)