chore: improve clc-lc test
This commit is contained in:
		| @@ -1,5 +1,5 @@ | ||||
| from dataclasses import dataclass | ||||
| from typing import Callable | ||||
| from typing import Callable, Type | ||||
|  | ||||
| import torch | ||||
| from prototorch.core.competitions import WTAC | ||||
| @@ -14,40 +14,48 @@ from prototorch.models.clcc.clcc_scheme import CLCCScheme | ||||
| from prototorch.nn.wrappers import LambdaLayer | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class GLVQhparams: | ||||
|     distribution: dict | ||||
|     component_initializer: AbstractComponentsInitializer | ||||
|     distance_fn: Callable = euclidean_distance | ||||
|     lr: float = 0.01 | ||||
|     margin: float = 0.0 | ||||
|     # TODO: make nicer | ||||
|     transfer_fn: str = "identity" | ||||
|     transfer_beta: float = 10.0 | ||||
|     optimizer: torch.optim.Optimizer = torch.optim.Adam | ||||
| class SupervisedScheme(CLCCScheme): | ||||
|  | ||||
|     @dataclass | ||||
|     class HyperParameters: | ||||
|         distribution: dict[str, int] | ||||
|         component_initializer: AbstractComponentsInitializer | ||||
|  | ||||
| class GLVQ(CLCCScheme): | ||||
|  | ||||
|     def __init__(self, hparams: GLVQhparams) -> None: | ||||
|         super().__init__(hparams) | ||||
|         self.lr = hparams.lr | ||||
|         self.optimizer = hparams.optimizer | ||||
|  | ||||
|     # Initializers | ||||
|     def init_components(self, hparams): | ||||
|         # initialize Component Layer | ||||
|     def init_components(self, hparams: HyperParameters): | ||||
|         self.components_layer = LabeledComponents( | ||||
|             distribution=hparams.distribution, | ||||
|             components_initializer=hparams.component_initializer, | ||||
|             labels_initializer=LabelsInitializer(), | ||||
|         ) | ||||
|  | ||||
|     def init_comparison(self, hparams): | ||||
|         # initialize Distance Layer | ||||
|  | ||||
| # ############################################################################## | ||||
| # GLVQ | ||||
| # ############################################################################## | ||||
| class GLVQ( | ||||
|         SupervisedScheme, ): | ||||
|     """GLVQ using the new Scheme | ||||
|     """ | ||||
|  | ||||
|     @dataclass | ||||
|     class HyperParameters(SupervisedScheme.HyperParameters): | ||||
|         distance_fn: Callable = euclidean_distance | ||||
|         lr: float = 0.01 | ||||
|         margin: float = 0.0 | ||||
|         # TODO: make nicer | ||||
|         transfer_fn: str = "identity" | ||||
|         transfer_beta: float = 10.0 | ||||
|         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam | ||||
|  | ||||
|     def __init__(self, hparams: HyperParameters) -> None: | ||||
|         super().__init__(hparams) | ||||
|         self.lr = hparams.lr | ||||
|         self.optimizer = hparams.optimizer | ||||
|  | ||||
|     def init_comparison(self, hparams: HyperParameters): | ||||
|         self.comparison_layer = LambdaLayer(hparams.distance_fn) | ||||
|  | ||||
|     def init_inference(self, hparams): | ||||
|     def init_inference(self, hparams: HyperParameters): | ||||
|         self.competition_layer = WTAC() | ||||
|  | ||||
|     def init_loss(self, hparams): | ||||
| @@ -78,7 +86,7 @@ class GLVQ(CLCCScheme): | ||||
|         return self.loss_layer(comparisonmeasures, target, comp_labels) | ||||
|  | ||||
|     def configure_optimizers(self): | ||||
|         return self.optimizer(self.parameters(), lr=self.lr) | ||||
|         return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore | ||||
|  | ||||
|     # Properties | ||||
|     @property | ||||
|   | ||||
| @@ -8,6 +8,7 @@ CLCC is a LVQ scheme containing 4 steps | ||||
| - Competition | ||||
|  | ||||
| """ | ||||
| from dataclasses import dataclass | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     Set, | ||||
| @@ -20,9 +21,16 @@ from torchmetrics import Accuracy, Metric | ||||
|  | ||||
|  | ||||
| class CLCCScheme(pl.LightningModule): | ||||
|  | ||||
|     @dataclass | ||||
|     class HyperParameters: | ||||
|         ... | ||||
|  | ||||
|     registered_metrics: Dict[Type[Metric], Metric] = {} | ||||
|     registered_metric_names: Dict[Type[Metric], Set[str]] = {} | ||||
|  | ||||
|     components_layer: pl.LightningModule | ||||
|  | ||||
|     def __init__(self, hparams) -> None: | ||||
|         super().__init__() | ||||
|  | ||||
| @@ -42,9 +50,14 @@ class CLCCScheme(pl.LightningModule): | ||||
|         self.init_model_metrics() | ||||
|  | ||||
|     # internal API, called by models and callbacks | ||||
|     def register_torchmetric(self, name: str, metric: Metric, **metricargs): | ||||
|     def register_torchmetric( | ||||
|         self, | ||||
|         name: str, | ||||
|         metric: Type[Metric], | ||||
|         **metric_kwargs, | ||||
|     ): | ||||
|         if metric not in self.registered_metrics: | ||||
|             self.registered_metrics[metric] = metric(**metricargs) | ||||
|             self.registered_metrics[metric] = metric(**metric_kwargs) | ||||
|             self.registered_metric_names[metric] = {name} | ||||
|         else: | ||||
|             self.registered_metric_names[metric].add(name) | ||||
| @@ -92,25 +105,25 @@ class CLCCScheme(pl.LightningModule): | ||||
|     # Empty Initialization | ||||
|     # TODO: Type hints | ||||
|     # TODO: Docs | ||||
|     def init_components(self, hparams): | ||||
|     def init_components(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_latent(self, hparams): | ||||
|     def init_latent(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_comparison(self, hparams): | ||||
|     def init_comparison(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_competition(self, hparams): | ||||
|     def init_competition(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_loss(self, hparams): | ||||
|     def init_loss(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_inference(self, hparams): | ||||
|     def init_inference(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_model_metrics(self): | ||||
|     def init_model_metrics(self) -> None: | ||||
|         self.register_torchmetric('accuracy', Accuracy) | ||||
|  | ||||
|     # Empty Steps | ||||
|   | ||||
| @@ -1,55 +1,76 @@ | ||||
| from typing import Optional | ||||
| from typing import Optional, Type | ||||
|  | ||||
| import matplotlib.pyplot as plt | ||||
| import numpy as np | ||||
| import prototorch as pt | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| import torchmetrics | ||||
| from prototorch.core.initializers import SMCI, RandomNormalCompInitializer | ||||
| from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams | ||||
| from prototorch.core import SMCI | ||||
| from prototorch.models.clcc.clcc_glvq import GLVQ | ||||
| from prototorch.models.clcc.clcc_scheme import CLCCScheme | ||||
| from prototorch.models.vis import Visualize2DVoronoiCallback | ||||
| from prototorch.models.vis import Vis2DAbstract | ||||
| from prototorch.utils.utils import mesh2d | ||||
| from pytorch_lightning.callbacks import EarlyStopping | ||||
| from torch.utils.data import DataLoader | ||||
|  | ||||
| # NEW STUFF | ||||
| # ############################################################################## | ||||
|  | ||||
|  | ||||
| # TODO: Metrics | ||||
| class MetricsTestCallback(pl.Callback): | ||||
|     metric_name = "test_cb_acc" | ||||
|  | ||||
|     def setup(self, | ||||
|               trainer: pl.Trainer, | ||||
|               pl_module: CLCCScheme, | ||||
|               stage: Optional[str] = None) -> None: | ||||
|         pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy) | ||||
|  | ||||
|     def on_epoch_end(self, trainer: pl.Trainer, | ||||
|                      pl_module: pl.LightningModule) -> None: | ||||
|         metric = trainer.logged_metrics[self.metric_name] | ||||
|         if metric > 0.95: | ||||
|             trainer.should_stop = True | ||||
|  | ||||
|  | ||||
| class LogTorchmetricCallback(pl.Callback): | ||||
|  | ||||
|     def __init__(self, name, metric, on="prediction", **metric_args) -> None: | ||||
|     def __init__( | ||||
|         self, | ||||
|         name, | ||||
|         metric: Type[torchmetrics.Metric], | ||||
|         on="prediction", | ||||
|         **metric_kwargs, | ||||
|     ) -> None: | ||||
|         self.name = name | ||||
|         self.metric = metric | ||||
|         self.metric_args = metric_args | ||||
|         self.metric_kwargs = metric_kwargs | ||||
|         self.on = on | ||||
|  | ||||
|     def setup(self, | ||||
|               trainer: pl.Trainer, | ||||
|               pl_module: CLCCScheme, | ||||
|               stage: Optional[str] = None) -> None: | ||||
|     def setup( | ||||
|         self, | ||||
|         trainer: pl.Trainer, | ||||
|         pl_module: CLCCScheme, | ||||
|         stage: Optional[str] = None, | ||||
|     ) -> None: | ||||
|         if self.on == "prediction": | ||||
|             pl_module.register_torchmetric(self.name, self.metric, | ||||
|                                            **self.metric_args) | ||||
|             pl_module.register_torchmetric( | ||||
|                 self.name, | ||||
|                 self.metric, | ||||
|                 **self.metric_kwargs, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError(f"{self.on} is no valid metric hook") | ||||
|  | ||||
|  | ||||
| class VisGLVQ2D(Vis2DAbstract): | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         protos = pl_module.prototypes | ||||
|         plabels = pl_module.prototype_labels | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
|         ax = self.setup_ax() | ||||
|         self.plot_protos(ax, protos, plabels) | ||||
|         if x_train is not None: | ||||
|             self.plot_data(ax, x_train, y_train) | ||||
|             mesh_input, xx, yy = mesh2d( | ||||
|                 np.vstack([x_train, protos]), | ||||
|                 self.border, | ||||
|                 self.resolution, | ||||
|             ) | ||||
|         else: | ||||
|             mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution) | ||||
|         _components = pl_module.components_layer.components | ||||
|         mesh_input = torch.from_numpy(mesh_input).type_as(_components) | ||||
|         y_pred = pl_module.predict(mesh_input) | ||||
|         y_pred = y_pred.cpu().reshape(xx.shape) | ||||
|         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) | ||||
|  | ||||
|  | ||||
| # TODO: Pruning | ||||
|  | ||||
| # ############################################################################## | ||||
| @@ -59,15 +80,17 @@ if __name__ == "__main__": | ||||
|     train_ds = pt.datasets.Iris(dims=[0, 2]) | ||||
|     train_ds.targets[train_ds.targets == 2.0] = 1.0 | ||||
|     # Dataloaders | ||||
|     train_loader = torch.utils.data.DataLoader(train_ds, | ||||
|                                                batch_size=64, | ||||
|                                                num_workers=0, | ||||
|                                                shuffle=True) | ||||
|     train_loader = DataLoader( | ||||
|         train_ds, | ||||
|         batch_size=64, | ||||
|         num_workers=0, | ||||
|         shuffle=True, | ||||
|     ) | ||||
|  | ||||
|     #components_initializer = SMCI(train_ds) | ||||
|     components_initializer = RandomNormalCompInitializer(2) | ||||
|     components_initializer = SMCI(train_ds) | ||||
|     #components_initializer = RandomNormalCompInitializer(2) | ||||
|  | ||||
|     hparams = GLVQhparams( | ||||
|     hyperparameters = GLVQ.HyperParameters( | ||||
|         lr=0.5, | ||||
|         distribution=dict( | ||||
|             num_classes=2, | ||||
| @@ -75,29 +98,36 @@ if __name__ == "__main__": | ||||
|         ), | ||||
|         component_initializer=components_initializer, | ||||
|     ) | ||||
|     model = GLVQ(hparams) | ||||
|  | ||||
|     model = GLVQ(hyperparameters) | ||||
|  | ||||
|     print(model) | ||||
|  | ||||
|     # Callbacks | ||||
|     vis = Visualize2DVoronoiCallback( | ||||
|         data=train_ds, | ||||
|         resolution=500, | ||||
|     vis = VisGLVQ2D(data=train_ds) | ||||
|     recall = LogTorchmetricCallback( | ||||
|         'recall', | ||||
|         torchmetrics.Recall, | ||||
|         num_classes=2, | ||||
|     ) | ||||
|  | ||||
|     es = EarlyStopping( | ||||
|         monitor="recall", | ||||
|         min_delta=0.001, | ||||
|         patience=15, | ||||
|         mode="max", | ||||
|         check_on_train_epoch_end=True, | ||||
|     ) | ||||
|     metrics = MetricsTestCallback() | ||||
|     recall = LogTorchmetricCallback('recall', | ||||
|                                     torchmetrics.Recall, | ||||
|                                     num_classes=2) | ||||
|  | ||||
|     # Train | ||||
|     trainer = pl.Trainer( | ||||
|         callbacks=[ | ||||
|             vis, | ||||
|             #metrics, | ||||
|             recall, | ||||
|             es, | ||||
|         ], | ||||
|         gpus=0, | ||||
|         max_epochs=200, | ||||
|         weights_summary=None, | ||||
|         log_every_n_steps=1, | ||||
|     ) | ||||
|     trainer.fit(model, train_loader) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user