Add validation and test logic
This commit is contained in:
@@ -52,7 +52,7 @@ class SiamesePrototypeModel(pl.LightningModule):
|
||||
backbone.
|
||||
|
||||
"""
|
||||
# model.eval() # ?!
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
protos, plabels = self.proto_layer()
|
||||
if map_protos:
|
||||
|
@@ -32,7 +32,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||
|
||||
# Default Values
|
||||
self.hparams.setdefault("transfer_function", "identity")
|
||||
self.hparams.setdefault("transfer_fn", "identity")
|
||||
self.hparams.setdefault("transfer_beta", 10.0)
|
||||
self.hparams.setdefault("lr", 0.01)
|
||||
|
||||
@@ -40,8 +40,8 @@ class GLVQ(AbstractPrototypeModel):
|
||||
distribution=self.hparams.distribution,
|
||||
initializer=prototype_initializer)
|
||||
|
||||
self.transfer_function = get_activation(self.hparams.transfer_function)
|
||||
self.train_acc = torchmetrics.Accuracy()
|
||||
self.transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||
self.acc_metric = torchmetrics.Accuracy()
|
||||
|
||||
self.loss = glvq_loss
|
||||
|
||||
@@ -54,18 +54,18 @@ class GLVQ(AbstractPrototypeModel):
|
||||
dis = self.distance_fn(x, protos)
|
||||
return dis
|
||||
|
||||
def log_acc(self, distances, targets):
|
||||
def log_acc(self, distances, targets, tag):
|
||||
plabels = self.proto_layer.component_labels
|
||||
|
||||
# Compute training accuracy
|
||||
with torch.no_grad():
|
||||
preds = wtac(distances, plabels)
|
||||
|
||||
self.train_acc(preds.int(), targets.int())
|
||||
self.acc_metric(preds.int(), targets.int())
|
||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||
|
||||
self.log("acc",
|
||||
self.train_acc,
|
||||
self.log(tag,
|
||||
self.acc_metric,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
@@ -76,18 +76,50 @@ class GLVQ(AbstractPrototypeModel):
|
||||
dis = self(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
||||
batch_loss = self.transfer_function(mu,
|
||||
train_batch_loss = self.transfer_fn(mu,
|
||||
beta=self.hparams.transfer_beta)
|
||||
loss = batch_loss.sum(dim=0)
|
||||
train_loss = train_batch_loss.sum(dim=0)
|
||||
|
||||
# Logging
|
||||
self.log("train_loss", loss)
|
||||
self.log_acc(dis, y)
|
||||
self.log("train_loss", train_loss)
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
return loss
|
||||
return train_loss
|
||||
|
||||
def validation_step(self, val_batch, batch_idx):
|
||||
# `model.eval()` and `torch.no_grad()` are called automatically for
|
||||
# validation.
|
||||
x, y = val_batch
|
||||
dis = self(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
||||
val_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
|
||||
val_loss = val_batch_loss.sum(dim=0)
|
||||
|
||||
# Logging
|
||||
self.log("val_loss", val_loss)
|
||||
self.log_acc(dis, y, tag="val_acc")
|
||||
|
||||
return val_loss
|
||||
|
||||
def test_step(self, test_batch, batch_idx):
|
||||
# `model.eval()` and `torch.no_grad()` are called automatically for
|
||||
# testing.
|
||||
x, y = test_batch
|
||||
dis = self(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
mu = self.loss(dis, y, prototype_labels=plabels)
|
||||
test_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
|
||||
test_loss = test_batch_loss.sum(dim=0)
|
||||
|
||||
# Logging
|
||||
self.log("test_loss", test_loss)
|
||||
self.log_acc(dis, y, tag="test_acc")
|
||||
|
||||
return test_loss
|
||||
|
||||
def predict(self, x):
|
||||
# model.eval() # ?!
|
||||
self.eval()
|
||||
with torch.no_grad():
|
||||
d = self(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
@@ -241,7 +273,7 @@ class LVQ1(NonGradientGLVQ):
|
||||
strict=False)
|
||||
|
||||
# Logging
|
||||
self.log_acc(dis, y)
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
return None
|
||||
|
||||
@@ -270,7 +302,7 @@ class LVQ21(NonGradientGLVQ):
|
||||
strict=False)
|
||||
|
||||
# Logging
|
||||
self.log_acc(dis, y)
|
||||
self.log_acc(dis, y, tag="train_acc")
|
||||
|
||||
return None
|
||||
|
||||
|
Reference in New Issue
Block a user