Add validation and test logic
parent 7700bb7f8d
commit fdf9443a2c
@@ -7,14 +7,14 @@ import torch
 if __name__ == "__main__":
     # Dataset
     train_ds = pt.datasets.Tecator(root="~/datasets/", train=True)
+    test_ds = pt.datasets.Tecator(root="~/datasets/", train=False)
 
     # Reproducibility
     pl.utilities.seed.seed_everything(seed=42)
 
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds,
-                                               num_workers=0,
-                                               batch_size=32)
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
+    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32)
 
     # Hyperparameters
     nclasses = 2
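Note: the new test split and loader feed the validation logic added later in this commit; the second loader is what gets handed to trainer.fit below. A quick sanity check of the loaders (a hypothetical snippet, not part of the commit; the expected feature count of 100 comes from input_dim in the next hunk):

    x, y = next(iter(train_loader))
    print(x.shape, y.shape)  # expected: torch.Size([32, 100]) torch.Size([32])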
@@ -23,8 +23,8 @@ if __name__ == "__main__":
         distribution=(nclasses, prototypes_per_class),
         input_dim=100,
         latent_dim=2,
-        proto_lr=0.001,
-        bb_lr=0.001,
+        proto_lr=0.005,
+        bb_lr=0.005,
     )
 
     # Initialize the model
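Note: proto_lr and bb_lr are presumably consumed by the model's configure_optimizers to train the prototypes and the backbone at different rates. A sketch of that common pattern (an assumption about this codebase, shown only to explain why two learning rates exist):

    # Sketch (assumed wiring, not code from this commit):
    def configure_optimizers(self):
        return torch.optim.Adam([
            {"params": self.proto_layer.parameters(), "lr": self.hparams.proto_lr},
            {"params": self.backbone.parameters(), "lr": self.hparams.bb_lr},
        ])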
@@ -35,10 +35,15 @@ if __name__ == "__main__":
     vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
 
     # Setup trainer
-    trainer = pl.Trainer(max_epochs=200, callbacks=[vis], gpus=0)
+    trainer = pl.Trainer(
+        gpus=0,
+        max_epochs=20,
+        callbacks=[vis],
+        weights_summary=None,
+    )
 
     # Training loop
-    trainer.fit(model, train_loader)
+    trainer.fit(model, train_loader, test_loader)
 
     # Save the model
     torch.save(model, "liramlvq_tecator.pt")
@@ -48,3 +53,7 @@ if __name__ == "__main__":
 
     # Display the Lambda matrix
     saved_model.show_lambda()
+
+    # Testing
+    # TODO
+    # trainer.test(model, test_dataloaders=test_loader)
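Note: once the TODO is resolved, uncommenting the call would exercise the new test_step added further down in this commit:

    trainer.test(model, test_dataloaders=test_loader)  # as commented above

The test_dataloaders keyword matches the PyTorch Lightning API of this era; newer releases renamed it to dataloaders.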
@@ -48,7 +48,7 @@ if __name__ == "__main__":
         hparams,
         prototype_initializer=pt.components.SMI(train_ds),
         backbone=backbone,
-        both_path_gradients=True,
+        both_path_gradients=False,
     )
 
     # Model summary
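Note: flipping both_path_gradients to False means the backbone is trained only through the data path; the prototype path through the backbone no longer contributes gradients. An illustrative sketch of the idea (an assumption about the model internals, not code from this commit):

    latent_x = self.backbone(x)                # data path: gradients flow
    with torch.no_grad():
        latent_protos = self.backbone(protos)  # prototype path: no backbone grads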
@@ -52,7 +52,7 @@ class SiamesePrototypeModel(pl.LightningModule):
         backbone.
 
         """
-        # model.eval() # ?!
+        self.eval()
         with torch.no_grad():
             protos, plabels = self.proto_layer()
             if map_protos:
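Note: self.eval() resolves the old "?!" comment. It switches modules such as Dropout and BatchNorm to inference behavior, while torch.no_grad() independently disables gradient tracking; both are wanted here. The generic PyTorch pattern:

    model.eval()           # deterministic inference behavior
    with torch.no_grad():  # no autograd bookkeeping
        output = model(inputs)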
@@ -32,7 +32,7 @@ class GLVQ(AbstractPrototypeModel):
         prototype_initializer = kwargs.get("prototype_initializer", None)
 
         # Default Values
-        self.hparams.setdefault("transfer_function", "identity")
+        self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
         self.hparams.setdefault("lr", 0.01)
 
@@ -40,8 +40,8 @@ class GLVQ(AbstractPrototypeModel):
             distribution=self.hparams.distribution,
             initializer=prototype_initializer)
 
-        self.transfer_function = get_activation(self.hparams.transfer_function)
-        self.train_acc = torchmetrics.Accuracy()
+        self.transfer_fn = get_activation(self.hparams.transfer_fn)
+        self.acc_metric = torchmetrics.Accuracy()
 
        self.loss = glvq_loss
 
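Note: the rename from transfer_function to transfer_fn changes the public hparams key; callers still passing the old key would silently fall back to the "identity" default. Construction with the new key might look like this (a sketch following the defaults above; the "sigmoid_beta" activation name is an assumption):

    hparams = dict(
        distribution=(2, 3),
        transfer_fn="sigmoid_beta",  # was: transfer_function=...
        transfer_beta=10.0,
        lr=0.01,
    )
    model = GLVQ(hparams)  # sketch; initializer kwargs omitted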
@@ -54,18 +54,18 @@ class GLVQ(AbstractPrototypeModel):
         dis = self.distance_fn(x, protos)
         return dis
 
-    def log_acc(self, distances, targets):
+    def log_acc(self, distances, targets, tag):
         plabels = self.proto_layer.component_labels
 
         # Compute training accuracy
         with torch.no_grad():
             preds = wtac(distances, plabels)
 
-        self.train_acc(preds.int(), targets.int())
+        self.acc_metric(preds.int(), targets.int())
         # `.int()` because FloatTensors are assumed to be class probabilities
 
-        self.log("acc",
-                 self.train_acc,
+        self.log(tag,
+                 self.acc_metric,
                  on_step=False,
                  on_epoch=True,
                  prog_bar=True,
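Note: threading a tag through log_acc lets the single shared torchmetrics.Accuracy instance log under phase-specific names. The three call sites added in this commit:

    self.log_acc(dis, y, tag="train_acc")  # training_step
    self.log_acc(dis, y, tag="val_acc")    # validation_step
    self.log_acc(dis, y, tag="test_acc")   # test_step

One caveat worth flagging: sharing one metric object across training, validation, and test loops can mix accumulated state; torchmetrics generally recommends a separate instance per phase.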
@@ -76,18 +76,50 @@ class GLVQ(AbstractPrototypeModel):
         dis = self(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(dis, y, prototype_labels=plabels)
-        batch_loss = self.transfer_function(mu,
-                                            beta=self.hparams.transfer_beta)
-        loss = batch_loss.sum(dim=0)
+        train_batch_loss = self.transfer_fn(mu,
+                                            beta=self.hparams.transfer_beta)
+        train_loss = train_batch_loss.sum(dim=0)
 
         # Logging
-        self.log("train_loss", loss)
-        self.log_acc(dis, y)
+        self.log("train_loss", train_loss)
+        self.log_acc(dis, y, tag="train_acc")
 
-        return loss
+        return train_loss
 
+    def validation_step(self, val_batch, batch_idx):
+        # `model.eval()` and `torch.no_grad()` are called automatically for
+        # validation.
+        x, y = val_batch
+        dis = self(x)
+        plabels = self.proto_layer.component_labels
+        mu = self.loss(dis, y, prototype_labels=plabels)
+        val_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
+        val_loss = val_batch_loss.sum(dim=0)
+
+        # Logging
+        self.log("val_loss", val_loss)
+        self.log_acc(dis, y, tag="val_acc")
+
+        return val_loss
+
+    def test_step(self, test_batch, batch_idx):
+        # `model.eval()` and `torch.no_grad()` are called automatically for
+        # testing.
+        x, y = test_batch
+        dis = self(x)
+        plabels = self.proto_layer.component_labels
+        mu = self.loss(dis, y, prototype_labels=plabels)
+        test_batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
+        test_loss = test_batch_loss.sum(dim=0)
+
+        # Logging
+        self.log("test_loss", test_loss)
+        self.log_acc(dis, y, tag="test_acc")
+
+        return test_loss
+
     def predict(self, x):
-        # model.eval() # ?!
+        self.eval()
         with torch.no_grad():
             d = self(x)
             plabels = self.proto_layer.component_labels
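Note: validation_step and test_step mirror training_step line for line; only the batch source and the logging tags differ, and Lightning wraps both in model.eval() plus torch.no_grad() automatically, as the added comments state. With the example script above, the hooks are exercised like so:

    trainer.fit(model, train_loader, test_loader)        # 2nd loader drives validation_step
    # trainer.test(model, test_dataloaders=test_loader)  # would drive test_step (still TODO)

The example reuses the test split as the validation split, which is a pragmatic shortcut for a demo but conflates the two roles in principle.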
@@ -241,7 +273,7 @@ class LVQ1(NonGradientGLVQ):
                                strict=False)
 
         # Logging
-        self.log_acc(dis, y)
+        self.log_acc(dis, y, tag="train_acc")
 
         return None
 
@@ -270,7 +302,7 @@ class LVQ21(NonGradientGLVQ):
                                strict=False)
 
         # Logging
-        self.log_acc(dis, y)
+        self.log_acc(dis, y, tag="train_acc")
 
         return None
 