# prototorch_models/prototorch/models/clcc/test_clcc.py
# (77 lines, 2.0 KiB, Python; revision dated 2021-10-15 13:18:02 +00:00)

from typing import Optional
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
# 2021-10-15 13:18:02 +00:00
import torchmetrics
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
# 2021-10-15 13:18:02 +00:00
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.vis import Visualize2DVoronoiCallback
# NEW STUFF
# ##############################################################################
# 2021-10-15 13:18:02 +00:00
# TODO: Metrics
class MetricsTestCallback(pl.Callback):
    """Register an accuracy metric on the model and stop training early.

    In ``setup`` the callback calls ``pl_module.register_torchmetric`` with
    ``metric_name`` and ``torchmetrics.Accuracy``. In ``on_epoch_end`` it
    reads the value stored under ``metric_name`` in
    ``trainer.logged_metrics`` and sets ``trainer.should_stop`` once that
    value exceeds ``stop_threshold``.
    """

    # Key used both to register the metric and to look it up afterwards.
    metric_name = "test_cb_acc"

    # Training is asked to stop once the logged metric is above this value.
    stop_threshold = 0.95

    def setup(self,
              trainer: pl.Trainer,
              pl_module: CLCCScheme,
              stage: Optional[str] = None) -> None:
        """Ask the model to register the accuracy metric under ``metric_name``.

        NOTE(review): presumably ``register_torchmetric`` also arranges for
        the metric to be logged under the same name -- confirm against
        ``CLCCScheme``.
        """
        pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)

    def on_epoch_end(self, trainer: pl.Trainer,
                     pl_module: pl.LightningModule) -> None:
        """Request early stopping when the logged metric passes the threshold."""
        # The hook can run before anything was logged under ``metric_name``;
        # a plain ``[...]`` lookup would then abort training with a KeyError.
        metric = trainer.logged_metrics.get(self.metric_name)
        if metric is None:
            return

        if metric > self.stop_threshold:
            trainer.should_stop = True
# TODO: Pruning
# ##############################################################################
if __name__ == "__main__":
    # Dataset: Iris, built with ``dims=[0, 2]``.
    train_ds = pt.datasets.Iris(dims=[0, 2])

    # Dataloaders
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=64,
        num_workers=8,
    )

    # Model: GLVQ with two components for each of the three classes; the
    # components are initialized by SMCI from the training set.
    components_initializer = SMCI(train_ds)
    hparams = GLVQhparams(
        distribution=dict(
            num_classes=3,
            per_class=2,
        ),
        component_initializer=components_initializer,
    )
    model = GLVQ(hparams)
    print(model)

    # Callbacks
    # ``vis`` is still constructed but is currently left out of the
    # trainer's callback list below.
    vis = Visualize2DVoronoiCallback(
        data=train_ds,
        resolution=500,
    )
    metrics = MetricsTestCallback()

    # Train
    # Use one GPU when CUDA is available and fall back to the CPU otherwise,
    # so the script also runs on machines without a GPU.
    num_gpus = 1 if torch.cuda.is_available() else 0
    trainer = pl.Trainer(
        callbacks=[
            # vis,
            metrics,
        ],
        gpus=num_gpus,
        max_epochs=100,
        weights_summary=None,
        log_every_n_steps=1,
    )
    trainer.fit(model, train_loader)