feat: add simple test step

This commit is contained in:
Alexander Engelsberger 2021-09-10 19:19:51 +02:00
parent fa928afe2c
commit d7ea89d47e
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
2 changed files with 26 additions and 6 deletions

View File

@ -6,6 +6,7 @@ import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
@ -14,12 +15,20 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
# Dataset # Dataset
x_train, y_train = load_iris(return_X_y=True) X, y = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]] X = X[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
X_train, X_test, y_train, y_test = train_test_split(X,
y,
test_size=0.5,
random_state=42)
train_ds = pt.datasets.NumpyDataset(X_train, y_train)
test_ds = pt.datasets.NumpyDataset(X_test, y_test)
# Dataloaders # Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16)
# Hyperparameters # Hyperparameters
hparams = dict(k=5) hparams = dict(k=5)
@ -35,7 +44,7 @@ if __name__ == "__main__":
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D( vis = pt.models.VisGLVQ2D(
data=(x_train, y_train), data=(X_train, y_train),
resolution=200, resolution=200,
block=True, block=True,
) )
@ -53,5 +62,8 @@ if __name__ == "__main__":
trainer.fit(model, train_loader) trainer.fit(model, train_loader)
# Recall # Recall
y_pred = model.predict(torch.tensor(x_train)) y_pred = model.predict(torch.tensor(X_train))
print(y_pred) print(y_pred)
# Test
trainer.test(model, dataloaders=test_loader)

View File

@ -162,6 +162,14 @@ class SupervisedPrototypeModel(PrototypeModel):
prog_bar=True, prog_bar=True,
logger=True) logger=True)
def test_step(self, batch, batch_idx):
x, targets = batch
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
self.log("test_acc", accuracy)
class ProtoTorchMixin(object): class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins.""" """All mixins are ProtoTorchMixins."""