[BUGFIX] Early stopping example works now

This commit is contained in:
Alexander Engelsberger 2021-06-03 13:38:16 +02:00
parent 64250d0938
commit 3b02d99ebe
3 changed files with 12 additions and 40 deletions

View File

@ -46,23 +46,23 @@ if __name__ == "__main__":
vis = pt.models.VisGLVQ2D(train_ds) vis = pt.models.VisGLVQ2D(train_ds)
pruning = pt.models.PruneLoserPrototypes( pruning = pt.models.PruneLoserPrototypes(
threshold=0.01, # prune prototype if it wins less than 1% threshold=0.01, # prune prototype if it wins less than 1%
idle_epochs=30, # pruning too early may cause problems idle_epochs=10, # pruning too early may cause problems
prune_quota_per_epoch=1, # prune at most 1 prototype per epoch prune_quota_per_epoch=5, # prune at most 5 prototypes per epoch
frequency=5, # prune every fifth epoch frequency=2, # prune every second epoch
verbose=True, verbose=True,
) )
es = pt.models.EarlyStopWithoutVal(
monitor="loss", es = pl.callbacks.EarlyStopping(
min_delta=0.1, monitor="train_loss",
patience=3, min_delta=0.001,
patience=15,
mode="min", mode="min",
verbose=True, check_on_train_epoch_end=True,
) )
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
max_epochs=250,
callbacks=[ callbacks=[
vis, vis,
pruning, pruning,

View File

@ -2,22 +2,10 @@
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence, from .callbacks import PrototypeConvergence, PruneLoserPrototypes
PruneLoserPrototypes)
from .cbc import CBC, ImageCBC from .cbc import CBC, ImageCBC
from .glvq import ( from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
GLVQ, ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .lvq import LVQ1, LVQ21, MedianLVQ from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import KNN, NeuralGas from .unsupervised import KNN, NeuralGas

View File

@ -4,22 +4,6 @@ import pytorch_lightning as pl
import torch import torch
class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
"""Run early stopping at the end of training loop.
See:
https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
"""
def on_validation_end(self, trainer, pl_module):
# override this to disable early stopping at the end of val loop
pass
def on_train_end(self, trainer, pl_module):
# instead, do it at the end of training loop
self._run_early_stopping_check(trainer, pl_module)
class PruneLoserPrototypes(pl.Callback): class PruneLoserPrototypes(pl.Callback):
def __init__(self, def __init__(self,
threshold=0.01, threshold=0.01,