[BUGFIX] Early stopping example works now

This commit is contained in:
Alexander Engelsberger 2021-06-03 13:38:16 +02:00
parent 64250d0938
commit 3b02d99ebe
3 changed files with 12 additions and 40 deletions

View File

@ -46,23 +46,23 @@ if __name__ == "__main__":
vis = pt.models.VisGLVQ2D(train_ds)
pruning = pt.models.PruneLoserPrototypes(
threshold=0.01, # prune prototype if it wins less than 1%
idle_epochs=30, # pruning too early may cause problems
prune_quota_per_epoch=1, # prune at most 1 prototype per epoch
frequency=5, # prune every fifth epoch
idle_epochs=10, # pruning too early may cause problems
prune_quota_per_epoch=5, # prune at most 5 prototypes per epoch
frequency=2, # prune every second epoch
verbose=True,
)
es = pt.models.EarlyStopWithoutVal(
monitor="loss",
min_delta=0.1,
patience=3,
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=15,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=250,
callbacks=[
vis,
pruning,

View File

@ -2,22 +2,10 @@
from importlib.metadata import PackageNotFoundError, version
from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence,
PruneLoserPrototypes)
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import KNN, NeuralGas

View File

@ -4,22 +4,6 @@ import pytorch_lightning as pl
import torch
class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
"""Run early stopping at the end of training loop.
See:
https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
"""
def on_validation_end(self, trainer, pl_module):
# override this to disable early stopping at the end of val loop
pass
def on_train_end(self, trainer, pl_module):
# instead, do it at the end of training loop
self._run_early_stopping_check(trainer, pl_module)
class PruneLoserPrototypes(pl.Callback):
def __init__(self,
threshold=0.01,