prototorch_models/prototorch/models/clcc/test_clcc.py

104 lines
3.0 KiB
Python
Raw Normal View History

2022-05-17 14:25:43 +00:00
from typing import Optional
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.vis import Visualize2DVoronoiCallback
# NEW STUFF
# ##############################################################################
# TODO: Metrics
class MetricsTestCallback(pl.Callback):
    """Request an early stop once a logged accuracy metric exceeds 0.95.

    In ``setup`` the callback calls ``pl_module.register_torchmetric`` with
    ``torchmetrics.Accuracy`` under ``metric_name``. At the end of each epoch
    it looks that name up in ``trainer.logged_metrics`` and sets
    ``trainer.should_stop`` when the value is above the threshold.
    """

    # Key used both for registering the metric and for looking it up again.
    metric_name = "test_cb_acc"

    def setup(self,
              trainer: pl.Trainer,
              pl_module: CLCCScheme,
              stage: Optional[str] = None) -> None:
        """Register ``torchmetrics.Accuracy`` on ``pl_module``.

        Args:
            trainer: The trainer running the fit (unused here).
            pl_module: Module that receives the metric registration.
            stage: Stage passed in by the trainer (unused here).
        """
        pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)

    def on_epoch_end(self, trainer: pl.Trainer,
                     pl_module: pl.LightningModule) -> None:
        """Stop training when the logged metric is greater than 0.95.

        Args:
            trainer: Trainer whose ``logged_metrics`` are inspected and whose
                ``should_stop`` flag is set.
            pl_module: The module being trained (unused here).
        """
        # The metric may not have been logged yet when this hook fires;
        # skip the check in that case instead of raising a KeyError.
        metric = trainer.logged_metrics.get(self.metric_name)
        if metric is None:
            return

        if metric > 0.95:
            trainer.should_stop = True
class LogTorchmetricCallback(pl.Callback):
    """Callback that registers an arbitrary torchmetric on the module.

    The metric class and its keyword arguments are stored at construction
    time and forwarded to ``pl_module.register_torchmetric`` during
    ``setup``. Only the ``"prediction"`` hook is accepted.
    """

    def __init__(self, name, metric, on="prediction", **metric_args) -> None:
        """Store the registration parameters.

        Args:
            name: Name under which the metric is registered.
            metric: Metric passed through to ``register_torchmetric``.
            on: Hook identifier; anything other than ``"prediction"`` is
                rejected later, in ``setup``.
            **metric_args: Extra keyword arguments forwarded with the metric.
        """
        self.on = on
        self.name = name
        self.metric = metric
        self.metric_args = metric_args

    def setup(self,
              trainer: pl.Trainer,
              pl_module: CLCCScheme,
              stage: Optional[str] = None) -> None:
        """Register the stored metric on ``pl_module``.

        Raises:
            ValueError: If ``on`` is not ``"prediction"``.
        """
        # Reject unsupported hooks first so the happy path stays flat.
        if self.on != "prediction":
            raise ValueError(f"{self.on} is no valid metric hook")

        pl_module.register_torchmetric(
            self.name,
            self.metric,
            **self.metric_args,
        )
# TODO: Pruning
# ##############################################################################
if __name__ == "__main__":
    # --- Data ---------------------------------------------------------------
    # Iris restricted via dims=[0, 2]; every target equal to 2.0 is rewritten
    # to 1.0, leaving the labels 0 and 1 that match num_classes=2 below.
    iris = pt.datasets.Iris(dims=[0, 2])
    is_class_two = iris.targets == 2.0
    iris.targets[is_class_two] = 1.0

    loader_options = {
        "batch_size": 64,
        "num_workers": 0,
        "shuffle": True,
    }
    loader = torch.utils.data.DataLoader(iris, **loader_options)

    # --- Model --------------------------------------------------------------
    # Alternative initializer (currently disabled): SMCI(iris)
    initializer = RandomNormalCompInitializer(2)

    prototype_distribution = {"num_classes": 2, "per_class": 1}
    config = GLVQhparams(
        lr=0.5,
        distribution=prototype_distribution,
        component_initializer=initializer,
    )

    glvq = GLVQ(config)
    print(glvq)

    # --- Callbacks ----------------------------------------------------------
    voronoi_vis = Visualize2DVoronoiCallback(data=iris, resolution=500)

    # Instantiated but deliberately not handed to the trainer at the moment.
    accuracy_stopper = MetricsTestCallback()

    recall_logger = LogTorchmetricCallback(
        'recall',
        torchmetrics.Recall,
        num_classes=2,
    )

    active_callbacks = [voronoi_vis, recall_logger]

    # --- Training -----------------------------------------------------------
    trainer = pl.Trainer(
        callbacks=active_callbacks,
        gpus=0,
        max_epochs=200,
        weights_summary=None,
        log_every_n_steps=1,
    )
    trainer.fit(glvq, loader)