diff --git a/examples/dynamic_pruning.py b/examples/dynamic_pruning.py index 0a2545f..2a4e81f 100644 --- a/examples/dynamic_pruning.py +++ b/examples/dynamic_pruning.py @@ -46,23 +46,23 @@ if __name__ == "__main__": vis = pt.models.VisGLVQ2D(train_ds) pruning = pt.models.PruneLoserPrototypes( threshold=0.01, # prune prototype if it wins less than 1% - idle_epochs=30, # pruning too early may cause problems - prune_quota_per_epoch=1, # prune at most 1 prototype per epoch - frequency=5, # prune every fifth epoch + idle_epochs=10, # pruning too early may cause problems + prune_quota_per_epoch=5, # prune at most 5 prototypes per epoch + frequency=2, # prune every second epoch verbose=True, ) - es = pt.models.EarlyStopWithoutVal( - monitor="loss", - min_delta=0.1, - patience=3, + + es = pl.callbacks.EarlyStopping( + monitor="train_loss", + min_delta=0.001, + patience=15, mode="min", - verbose=True, + check_on_train_epoch_end=True, ) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - max_epochs=250, callbacks=[ vis, pruning, diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index a539264..251370c 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -2,22 +2,10 @@ from importlib.metadata import PackageNotFoundError, version -from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence, - PruneLoserPrototypes) +from .callbacks import PrototypeConvergence, PruneLoserPrototypes from .cbc import CBC, ImageCBC -from .glvq import ( - GLVQ, - GLVQ1, - GLVQ21, - GMLVQ, - GRLVQ, - LGMLVQ, - LVQMLN, - ImageGLVQ, - ImageGMLVQ, - SiameseGLVQ, - SiameseGMLVQ, -) +from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN, + ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) from .lvq import LVQ1, LVQ21, MedianLVQ from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ from .unsupervised import KNN, NeuralGas diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py index 48f40fa..d6643fb 100644 --- a/prototorch/models/callbacks.py +++ b/prototorch/models/callbacks.py @@ -4,22 +4,6 @@ import pytorch_lightning as pl import torch -class EarlyStopWithoutVal(pl.callbacks.EarlyStopping): - """Run early stopping at the end of training loop. - - See: - https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html - - """ - def on_validation_end(self, trainer, pl_module): - # override this to disable early stopping at the end of val loop - pass - - def on_train_end(self, trainer, pl_module): - # instead, do it at the end of training loop - self._run_early_stopping_check(trainer, pl_module) - - class PruneLoserPrototypes(pl.Callback): def __init__(self, threshold=0.01,