From ef6bcc1079e8306880820e018470504654738725 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 2 Jun 2021 12:44:34 +0200 Subject: [PATCH] [BUG] Early stopping does not seem to work The early stopping callback does not work as expected, and crashes at the end of max_epochs with: ``` ~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py in on_train_end(self) 155 """Called when the train ends.""" 156 for callback in self.callbacks: --> 157 callback.on_train_end(self, self.lightning_module) 158 159 def on_pretrain_routine_start(self) -> None: ~/work/repos/prototorch_models/prototorch/models/callbacks.py in on_train_end(self, trainer, pl_module) 18 def on_train_end(self, trainer, pl_module): 19 # instead, do it at the end of training loop ---> 20 self._run_early_stopping_check(trainer, pl_module) 21 22 TypeError: _run_early_stopping_check() takes 2 positional arguments but 3 were given ``` --- examples/dynamic_pruning.py | 12 +++++++++-- prototorch/models/__init__.py | 3 ++- prototorch/models/callbacks.py | 37 +++++++++++++++++++++++++++++++--- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/examples/dynamic_pruning.py b/examples/dynamic_pruning.py index e9f8946..0a2545f 100644 --- a/examples/dynamic_pruning.py +++ b/examples/dynamic_pruning.py @@ -46,19 +46,27 @@ if __name__ == "__main__": vis = pt.models.VisGLVQ2D(train_ds) pruning = pt.models.PruneLoserPrototypes( threshold=0.01, # prune prototype if it wins less than 1% - prune_after_epochs=30, # pruning too early may cause problems + idle_epochs=30, # pruning too early may cause problems prune_quota_per_epoch=1, # prune at most 1 prototype per epoch frequency=5, # prune every fifth epoch verbose=True, ) + es = pt.models.EarlyStopWithoutVal( + monitor="loss", + min_delta=0.1, + patience=3, + mode="min", + verbose=True, + ) # Setup trainer trainer = pl.Trainer.from_argparse_args( args, - max_epochs=100, + max_epochs=250, callbacks=[ vis, pruning, + 
es, ], terminate_on_nan=True, weights_summary=None, diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index c7794e2..a539264 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -2,7 +2,8 @@ from importlib.metadata import PackageNotFoundError, version -from .callbacks import PruneLoserPrototypes +from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence, + PruneLoserPrototypes) from .cbc import CBC, ImageCBC from .glvq import ( GLVQ, diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py index 8f8a7c4..48f40fa 100644 --- a/prototorch/models/callbacks.py +++ b/prototorch/models/callbacks.py @@ -4,15 +4,31 @@ import pytorch_lightning as pl import torch +class EarlyStopWithoutVal(pl.callbacks.EarlyStopping): + """Run early stopping at the end of training loop. + + See: + https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html + NOTE(review): confirm on_train_end can still stop a run early. + """ + def on_validation_end(self, trainer, pl_module): + # override this to disable early stopping at the end of val loop + pass + + def on_train_end(self, trainer, pl_module): + # instead, run the check when training ends; it takes the trainer only + self._run_early_stopping_check(trainer) + + class PruneLoserPrototypes(pl.Callback): def __init__(self, threshold=0.01, - prune_after_epochs=10, + idle_epochs=10, prune_quota_per_epoch=-1, frequency=1, verbose=False): self.threshold = threshold # minimum win ratio - self.prune_after_epochs = prune_after_epochs # epochs to wait + self.idle_epochs = idle_epochs # epochs to wait before pruning self.prune_quota_per_epoch = prune_quota_per_epoch self.frequency = frequency self.verbose = verbose @@ -21,7 +37,7 @@ class PruneLoserPrototypes(pl.Callback): pl_module.initialize_prototype_win_ratios() def on_epoch_end(self, trainer, pl_module): - if (trainer.current_epoch + 1) < self.prune_after_epochs: + if (trainer.current_epoch + 1) < self.idle_epochs: return None if (trainer.current_epoch + 1) % self.frequency: 
return None @@ -40,3 +56,18 @@ class PruneLoserPrototypes(pl.Callback): print(f"`num_prototypes` reduced from {cur_num_protos} " f"to {new_num_protos}.") return True + + +class PrototypeConvergence(pl.Callback): + def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False): + self.min_delta = min_delta + self.idle_epochs = idle_epochs # epochs to skip before this hook acts + self.verbose = verbose + + def on_epoch_end(self, trainer, pl_module): + if (trainer.current_epoch + 1) < self.idle_epochs: + return None + if self.verbose: + print("Stopping...") + # TODO: no convergence test yet -- min_delta is unused and nothing here asks the trainer to stop + return True