Add validation and test logic

This commit is contained in:
Jensun Ravichandran 2021-05-19 16:30:19 +02:00
parent 7700bb7f8d
commit fdf9443a2c
4 changed files with 65 additions and 24 deletions

View File

@ -7,14 +7,14 @@ import torch
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
train_ds = pt.datasets.Tecator(root="~/datasets/", train=True) train_ds = pt.datasets.Tecator(root="~/datasets/", train=True)
test_ds = pt.datasets.Tecator(root="~/datasets/", train=False)
# Reproducibility # Reproducibility
pl.utilities.seed.seed_everything(seed=42) pl.utilities.seed.seed_everything(seed=42)
# Dataloaders # Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
num_workers=0, test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32)
batch_size=32)
# Hyperparameters # Hyperparameters
nclasses = 2 nclasses = 2
@ -23,8 +23,8 @@ if __name__ == "__main__":
distribution=(nclasses, prototypes_per_class), distribution=(nclasses, prototypes_per_class),
input_dim=100, input_dim=100,
latent_dim=2, latent_dim=2,
proto_lr=0.001, proto_lr=0.005,
bb_lr=0.001, bb_lr=0.005,
) )
# Initialize the model # Initialize the model
@ -35,10 +35,15 @@ if __name__ == "__main__":
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1) vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis], gpus=0) trainer = pl.Trainer(
gpus=0,
max_epochs=20,
callbacks=[vis],
weights_summary=None,
)
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader, test_loader)
# Save the model # Save the model
torch.save(model, "liramlvq_tecator.pt") torch.save(model, "liramlvq_tecator.pt")
@ -48,3 +53,7 @@ if __name__ == "__main__":
# Display the Lambda matrix # Display the Lambda matrix
saved_model.show_lambda() saved_model.show_lambda()
# Testing
# TODO
# trainer.test(model, test_dataloaders=test_loader)

View File

@ -48,7 +48,7 @@ if __name__ == "__main__":
hparams, hparams,
prototype_initializer=pt.components.SMI(train_ds), prototype_initializer=pt.components.SMI(train_ds),
backbone=backbone, backbone=backbone,
both_path_gradients=True, both_path_gradients=False,
) )
# Model summary # Model summary

View File

@ -52,7 +52,7 @@ class SiamesePrototypeModel(pl.LightningModule):
backbone. backbone.
""" """
# model.eval() # ?! self.eval()
with torch.no_grad(): with torch.no_grad():
protos, plabels = self.proto_layer() protos, plabels = self.proto_layer()
if map_protos: if map_protos:

View File

@ -32,7 +32,7 @@ class GLVQ(AbstractPrototypeModel):
prototype_initializer = kwargs.get("prototype_initializer", None) prototype_initializer = kwargs.get("prototype_initializer", None)
# Default Values # Default Values
self.hparams.setdefault("transfer_function", "identity") self.hparams.setdefault("transfer_fn", "identity")
self.hparams.setdefault("transfer_beta", 10.0) self.hparams.setdefault("transfer_beta", 10.0)
self.hparams.setdefault("lr", 0.01) self.hparams.setdefault("lr", 0.01)
@ -40,8 +40,8 @@ class GLVQ(AbstractPrototypeModel):
distribution=self.hparams.distribution, distribution=self.hparams.distribution,
initializer=prototype_initializer) initializer=prototype_initializer)
self.transfer_function = get_activation(self.hparams.transfer_function) self.transfer_fn = get_activation(self.hparams.transfer_fn)
self.train_acc = torchmetrics.Accuracy() self.acc_metric = torchmetrics.Accuracy()
self.loss = glvq_loss self.loss = glvq_loss
@ -54,18 +54,18 @@ class GLVQ(AbstractPrototypeModel):
dis = self.distance_fn(x, protos) dis = self.distance_fn(x, protos)
return dis return dis
def log_acc(self, distances, targets): def log_acc(self, distances, targets, tag):
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
# Compute training accuracy # Compute training accuracy
with torch.no_grad(): with torch.no_grad():
preds = wtac(distances, plabels) preds = wtac(distances, plabels)
self.train_acc(preds.int(), targets.int()) self.acc_metric(preds.int(), targets.int())
# `.int()` because FloatTensors are assumed to be class probabilities # `.int()` because FloatTensors are assumed to be class probabilities
self.log("acc", self.log(tag,
self.train_acc, self.acc_metric,
on_step=False, on_step=False,
on_epoch=True, on_epoch=True,
prog_bar=True, prog_bar=True,
@ -76,18 +76,50 @@ class GLVQ(AbstractPrototypeModel):
dis = self(x) dis = self(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
mu = self.loss(dis, y, prototype_labels=plabels) mu = self.loss(dis, y, prototype_labels=plabels)
batch_loss = self.transfer_function(mu, train_batch_loss = self.transfer_fn(mu,
beta=self.hparams.transfer_beta) beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0) train_loss = train_batch_loss.sum(dim=0)
# Logging # Logging
self.log("train_loss", loss) self.log("train_loss", train_loss)
self.log_acc(dis, y) self.log_acc(dis, y, tag="train_acc")
return loss return train_loss
def validation_step(self, val_batch, batch_idx):
    """Score one validation batch and log its loss and accuracy.

    The batch is unpacked into inputs and targets, the model's forward
    pass yields the distances, and the loss is passed through the
    transfer function (using ``hparams.transfer_beta``) before being
    summed along dim 0. The summed loss is logged as ``"val_loss"`` and
    the accuracy is logged via ``log_acc`` under the tag ``"val_acc"``.

    Args:
        val_batch: A two-element batch that unpacks to ``(x, y)``.
        batch_idx: Index of the batch; not used here.

    Returns:
        The summed validation loss for this batch.
    """
    # Per the original author's note: `model.eval()` and `torch.no_grad()`
    # are called automatically for validation, so neither is set up here.
    inputs, targets = val_batch
    distances = self(inputs)
    proto_labels = self.proto_layer.component_labels
    mu = self.loss(distances, targets, prototype_labels=proto_labels)
    transferred = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
    val_loss = transferred.sum(dim=0)

    # Logging
    self.log("val_loss", val_loss)
    self.log_acc(distances, targets, tag="val_acc")
    return val_loss
def test_step(self, test_batch, batch_idx):
# `model.eval()` and `torch.no_grad()` are called automatically for
# testing.
x, y = test_batch
dis = self(x)
plabels = self.proto_layer.component_labels
mu = self.loss(dis, y, prototype_labels=plabels)
test_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
test_loss = test_batch_loss.sum(dim=0)
# Logging
self.log("test_loss", test_loss)
self.log_acc(dis, y, tag="test_acc")
return test_loss
def predict(self, x): def predict(self, x):
# model.eval() # ?! self.eval()
with torch.no_grad(): with torch.no_grad():
d = self(x) d = self(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
@ -241,7 +273,7 @@ class LVQ1(NonGradientGLVQ):
strict=False) strict=False)
# Logging # Logging
self.log_acc(dis, y) self.log_acc(dis, y, tag="train_acc")
return None return None
@ -270,7 +302,7 @@ class LVQ21(NonGradientGLVQ):
strict=False) strict=False)
# Logging # Logging
self.log_acc(dis, y) self.log_acc(dis, y, tag="train_acc")
return None return None