Use Components instead of Prototypes and refactor old examples
This commit is contained in:
@@ -1,14 +1,16 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
|
||||
from prototorch.components import LabeledComponents
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.functions.losses import glvq_loss
|
||||
from prototorch.modules.prototypes import Prototypes1D
|
||||
|
||||
from .abstract import AbstractPrototypeModel
|
||||
|
||||
class GLVQ(pl.LightningModule):
|
||||
|
||||
class GLVQ(AbstractPrototypeModel):
|
||||
"""Generalized Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__()
|
||||
@@ -18,29 +20,18 @@ class GLVQ(pl.LightningModule):
|
||||
# Default Values
|
||||
self.hparams.setdefault("distance", euclidean_distance)
|
||||
|
||||
self.proto_layer = Prototypes1D(
|
||||
input_dim=self.hparams.input_dim,
|
||||
nclasses=self.hparams.nclasses,
|
||||
prototypes_per_class=self.hparams.prototypes_per_class,
|
||||
prototype_initializer=self.hparams.prototype_initializer,
|
||||
**kwargs)
|
||||
self.proto_layer = LabeledComponents(
|
||||
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
|
||||
initializer=self.hparams.prototype_initializer)
|
||||
|
||||
self.train_acc = torchmetrics.Accuracy()
|
||||
|
||||
@property
|
||||
def prototypes(self):
|
||||
return self.proto_layer.prototypes.detach().numpy()
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
return self.proto_layer.prototype_labels.detach().numpy()
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
||||
return optimizer
|
||||
return self.proto_layer.component_labels.detach().numpy()
|
||||
|
||||
def forward(self, x):
|
||||
protos = self.proto_layer.prototypes
|
||||
protos, _ = self.proto_layer()
|
||||
dis = self.hparams.distance(x, protos)
|
||||
return dis
|
||||
|
||||
@@ -48,7 +39,7 @@ class GLVQ(pl.LightningModule):
|
||||
x, y = train_batch
|
||||
x = x.view(x.size(0), -1)
|
||||
dis = self(x)
|
||||
plabels = self.proto_layer.prototype_labels
|
||||
plabels = self.proto_layer.component_labels
|
||||
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
||||
loss = mu.sum(dim=0)
|
||||
self.log("train_loss", loss)
|
||||
@@ -77,7 +68,7 @@ class GLVQ(pl.LightningModule):
|
||||
# model.eval() # ?!
|
||||
with torch.no_grad():
|
||||
d = self(x)
|
||||
plabels = self.proto_layer.prototype_labels
|
||||
plabels = self.proto_layer.component_labels
|
||||
y_pred = wtac(d, plabels)
|
||||
return y_pred.numpy()
|
||||
|
||||
@@ -89,7 +80,7 @@ class ImageGLVQ(GLVQ):
|
||||
clamping after updates.
|
||||
"""
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
|
||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||
|
||||
|
||||
class SiameseGLVQ(GLVQ):
|
||||
@@ -115,7 +106,7 @@ class SiameseGLVQ(GLVQ):
|
||||
|
||||
def forward(self, x):
|
||||
self.sync_backbones()
|
||||
protos = self.proto_layer.prototypes
|
||||
protos, _ = self.proto_layer()
|
||||
|
||||
latent_x = self.backbone(x)
|
||||
latent_protos = self.backbone_dependent(protos)
|
||||
@@ -126,9 +117,8 @@ class SiameseGLVQ(GLVQ):
|
||||
def predict_latent(self, x):
|
||||
# model.eval() # ?!
|
||||
with torch.no_grad():
|
||||
protos = self.proto_layer.prototypes
|
||||
protos, plabels = self.proto_layer()
|
||||
latent_protos = self.backbone_dependent(protos)
|
||||
d = euclidean_distance(x, latent_protos)
|
||||
plabels = self.proto_layer.prototype_labels
|
||||
y_pred = wtac(d, plabels)
|
||||
return y_pred.numpy()
|
||||
|
Reference in New Issue
Block a user