[BUG] Early stopping does not seem to work

The early stopping callback (`EarlyStopWithoutVal`) does not work as expected: instead of stopping training early, it crashes at the end of `max_epochs` with:

```
~/miniconda3/envs/py38/lib/python3.8/site-packages/pytorch_lightning/trainer/callback_hook.py in on_train_end(self)
    155         """Called when the train ends."""
    156         for callback in self.callbacks:
--> 157             callback.on_train_end(self, self.lightning_module)
    158
    159     def on_pretrain_routine_start(self) -> None:

~/work/repos/prototorch_models/prototorch/models/callbacks.py in on_train_end(self, trainer, pl_module)
     18     def on_train_end(self, trainer, pl_module):
     19         # instead, do it at the end of training loop
---> 20         self._run_early_stopping_check(trainer, pl_module)
     21
     22

TypeError: _run_early_stopping_check() takes 2 positional arguments but 3 were given
```
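
For reference, `_run_early_stopping_check` is a private helper inherited from `pl.callbacks.EarlyStopping`, and in recent pytorch-lightning releases (1.3+, if I am reading the changelog correctly) it takes only the trainer, which is what triggers the `TypeError`. I also suspect the "does not seem to work" part comes from `on_train_end` itself: that hook fires once after the whole fit, so a check there cannot stop training early in the first place. Below is a minimal sketch of the override with the newer call signature; it still leans on the private helper, so treat it as a guess rather than the intended fix:

```python
import pytorch_lightning as pl


class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
    """Sketch: early-stopping check without a validation loop.

    Assumes a pytorch-lightning version whose private
    `_run_early_stopping_check` accepts only the trainer.
    """
    def on_validation_end(self, trainer, pl_module):
        # Disable the check that normally runs after the validation loop.
        pass

    def on_train_end(self, trainer, pl_module):
        # Newer Lightning dropped `pl_module` from the private helper,
        # hence the TypeError in the traceback above.
        self._run_early_stopping_check(trainer)
```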

The callback in question was added in the following commit (diff included below for reference):

commit ef6bcc1079 (parent bdacc83185)
Jensun Ravichandran, 2021-06-02 12:44:34 +02:00
3 changed files with 46 additions and 6 deletions


```diff
@@ -46,19 +46,27 @@ if __name__ == "__main__":
     vis = pt.models.VisGLVQ2D(train_ds)
     pruning = pt.models.PruneLoserPrototypes(
         threshold=0.01,  # prune prototype if it wins less than 1%
-        prune_after_epochs=30,  # pruning too early may cause problems
+        idle_epochs=30,  # pruning too early may cause problems
         prune_quota_per_epoch=1,  # prune at most 1 prototype per epoch
         frequency=5,  # prune every fifth epoch
         verbose=True,
     )
+    es = pt.models.EarlyStopWithoutVal(
+        monitor="loss",
+        min_delta=0.1,
+        patience=3,
+        mode="min",
+        verbose=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        max_epochs=100,
+        max_epochs=250,
         callbacks=[
             vis,
             pruning,
+            es,
         ],
         terminate_on_nan=True,
         weights_summary=None,
```
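
One thing worth double-checking in the example configuration above: `EarlyStopping` (and therefore `EarlyStopWithoutVal`) resolves `monitor="loss"` against `trainer.callback_metrics`, which is only populated by `self.log(...)` calls inside the LightningModule. A metric literally named `"loss"` therefore has to be logged somewhere for the callback to find it. The toy module below is only meant to illustrate where that key comes from; the prototorch models presumably already log their training loss, possibly under a different key:

```python
import torch
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):
    """Minimal module showing where the monitored "loss" key comes from."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # EarlyStopping looks the monitored key up in trainer.callback_metrics,
        # which self.log fills; the key must match the callback's `monitor`.
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```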


```diff
@@ -2,7 +2,8 @@
 
 from importlib.metadata import PackageNotFoundError, version
 
-from .callbacks import PruneLoserPrototypes
+from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence,
+                        PruneLoserPrototypes)
 from .cbc import CBC, ImageCBC
 from .glvq import (
     GLVQ,
```


```diff
@@ -4,15 +4,31 @@ import pytorch_lightning as pl
 import torch
 
 
+class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
+    """Run early stopping at the end of training loop.
+
+    See:
+    https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
+
+    """
+    def on_validation_end(self, trainer, pl_module):
+        # override this to disable early stopping at the end of val loop
+        pass
+
+    def on_train_end(self, trainer, pl_module):
+        # instead, do it at the end of training loop
+        self._run_early_stopping_check(trainer, pl_module)
+
+
 class PruneLoserPrototypes(pl.Callback):
     def __init__(self,
                  threshold=0.01,
-                 prune_after_epochs=10,
+                 idle_epochs=10,
                  prune_quota_per_epoch=-1,
                  frequency=1,
                  verbose=False):
         self.threshold = threshold  # minimum win ratio
-        self.prune_after_epochs = prune_after_epochs  # epochs to wait
+        self.idle_epochs = idle_epochs  # epochs to wait before pruning
         self.prune_quota_per_epoch = prune_quota_per_epoch
         self.frequency = frequency
         self.verbose = verbose
@@ -21,7 +37,7 @@ class PruneLoserPrototypes(pl.Callback):
         pl_module.initialize_prototype_win_ratios()
 
     def on_epoch_end(self, trainer, pl_module):
-        if (trainer.current_epoch + 1) < self.prune_after_epochs:
+        if (trainer.current_epoch + 1) < self.idle_epochs:
             return None
         if (trainer.current_epoch + 1) % self.frequency:
             return None
@@ -40,3 +56,18 @@ class PruneLoserPrototypes(pl.Callback):
             print(f"`num_prototypes` reduced from {cur_num_protos} "
                   f"to {new_num_protos}.")
         return True
+
+
+class PrototypeConvergence(pl.Callback):
+    def __init__(self, min_delta=0.01, idle_epochs=10, verbose=False):
+        self.min_delta = min_delta
+        self.idle_epochs = idle_epochs  # epochs to wait
+        self.verbose = verbose
+
+    def on_epoch_end(self, trainer, pl_module):
+        if (trainer.current_epoch + 1) < self.idle_epochs:
+            return None
+        if self.verbose:
+            print("Stopping...")
+        # TODO
+        return True
```
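
To make the epoch gating in `PruneLoserPrototypes` (and the `idle_epochs` check that `PrototypeConvergence` reuses) a bit more concrete, here is a small standalone sketch of the same arithmetic, assuming `trainer.current_epoch` is 0-indexed as in pytorch-lightning:

```python
def should_act(current_epoch, idle_epochs=30, frequency=5):
    """Mirror of the gating in PruneLoserPrototypes.on_epoch_end."""
    if (current_epoch + 1) < idle_epochs:
        return False  # still within the idle period
    if (current_epoch + 1) % frequency:
        return False  # only act on every `frequency`-th epoch
    return True


# With idle_epochs=30 and frequency=5 (the values from the example script),
# the callback first acts after the 30th epoch and then every 5th epoch.
assert [e for e in range(45) if should_act(e)] == [29, 34, 39, 44]
```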