feat: add simple test step
This commit is contained in:
		| @@ -6,6 +6,7 @@ import prototorch as pt | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from sklearn.datasets import load_iris | ||||
| from sklearn.model_selection import train_test_split | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # Command-line arguments | ||||
| @@ -14,12 +15,20 @@ if __name__ == "__main__": | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     # Dataset | ||||
|     x_train, y_train = load_iris(return_X_y=True) | ||||
|     x_train = x_train[:, [0, 2]] | ||||
|     train_ds = pt.datasets.NumpyDataset(x_train, y_train) | ||||
|     X, y = load_iris(return_X_y=True) | ||||
|     X = X[:, [0, 2]] | ||||
|  | ||||
|     X_train, X_test, y_train, y_test = train_test_split(X, | ||||
|                                                         y, | ||||
|                                                         test_size=0.5, | ||||
|                                                         random_state=42) | ||||
|  | ||||
|     train_ds = pt.datasets.NumpyDataset(X_train, y_train) | ||||
|     test_ds = pt.datasets.NumpyDataset(X_test, y_test) | ||||
|  | ||||
|     # Dataloaders | ||||
|     train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150) | ||||
|     train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16) | ||||
|     test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16) | ||||
|  | ||||
|     # Hyperparameters | ||||
|     hparams = dict(k=5) | ||||
| @@ -35,7 +44,7 @@ if __name__ == "__main__": | ||||
|  | ||||
|     # Callbacks | ||||
|     vis = pt.models.VisGLVQ2D( | ||||
|         data=(x_train, y_train), | ||||
|         data=(X_train, y_train), | ||||
|         resolution=200, | ||||
|         block=True, | ||||
|     ) | ||||
| @@ -53,5 +62,8 @@ if __name__ == "__main__": | ||||
|     trainer.fit(model, train_loader) | ||||
|  | ||||
|     # Recall | ||||
|     y_pred = model.predict(torch.tensor(x_train)) | ||||
|     y_pred = model.predict(torch.tensor(X_train)) | ||||
|     print(y_pred) | ||||
|  | ||||
|     # Test | ||||
|     trainer.test(model, dataloaders=test_loader) | ||||
|   | ||||
| @@ -162,6 +162,14 @@ class SupervisedPrototypeModel(PrototypeModel): | ||||
|                  prog_bar=True, | ||||
|                  logger=True) | ||||
|  | ||||
|     def test_step(self, batch, batch_idx): | ||||
|         x, targets = batch | ||||
|  | ||||
|         preds = self.predict(x) | ||||
|         accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) | ||||
|  | ||||
|         self.log("test_acc", accuracy) | ||||
|  | ||||
|  | ||||
| class ProtoTorchMixin(object): | ||||
|     """All mixins are ProtoTorchMixins.""" | ||||
|   | ||||
		Reference in New Issue
	
	Block a user