Add partial metric/hparam features [BROKEN STATE]
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.functions.initializers import get_initializer
|
||||
@@ -9,10 +10,11 @@ from prototorch.modules.prototypes import Prototypes1D
|
||||
|
||||
class GLVQ(pl.LightningModule):
|
||||
"""Generalized Learning Vector Quantization."""
|
||||
def __init__(self, lr=1e-3, **kwargs):
|
||||
def __init__(self, hparams):
|
||||
super().__init__()
|
||||
self.lr = lr
|
||||
self.lr = hparams.lr
|
||||
self.proto_layer = Prototypes1D(**kwargs)
|
||||
self.train_acc = torchmetrics.Accuracy()
|
||||
|
||||
@property
|
||||
def prototypes(self):
|
||||
@@ -39,10 +41,21 @@ class GLVQ(pl.LightningModule):
|
||||
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
||||
loss = mu.sum(dim=0)
|
||||
self.log("train_loss", loss)
|
||||
with torch.no_grad():
|
||||
preds = wtac(dis, plabels)
|
||||
# self.train_acc.update(preds.int(), y.int())
|
||||
self.train_acc(preds.int(), y.int()) # FloatTensors are assumed to be class probabilities
|
||||
self.log("Training Accuracy", self.train_acc, on_step=False, on_epoch=True)
|
||||
return loss
|
||||
|
||||
# def training_epoch_end(self, outs):
|
||||
# # Calling `self.train_acc.compute()` is
|
||||
# # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)`
|
||||
# self.log("train_acc_epoch", self.train_acc.compute())
|
||||
|
||||
def predict(self, x):
|
||||
with torch.no_grad():
|
||||
# model.eval() # ?!
|
||||
d = self(x)
|
||||
plabels = self.proto_layer.prototype_labels
|
||||
y_pred = wtac(d, plabels)
|
||||
|
Reference in New Issue
Block a user