[BUGFIX] Early stopping example works now
This commit is contained in:
parent
64250d0938
commit
3b02d99ebe
@ -46,23 +46,23 @@ if __name__ == "__main__":
|
||||
vis = pt.models.VisGLVQ2D(train_ds)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01, # prune prototype if it wins less than 1%
|
||||
idle_epochs=30, # pruning too early may cause problems
|
||||
prune_quota_per_epoch=1, # prune at most 1 prototype per epoch
|
||||
frequency=5, # prune every fifth epoch
|
||||
idle_epochs=10, # pruning too early may cause problems
|
||||
prune_quota_per_epoch=5, # prune at most 5 prototypes per epoch
|
||||
frequency=2, # prune every second epoch
|
||||
verbose=True,
|
||||
)
|
||||
es = pt.models.EarlyStopWithoutVal(
|
||||
monitor="loss",
|
||||
min_delta=0.1,
|
||||
patience=3,
|
||||
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=15,
|
||||
mode="min",
|
||||
verbose=True,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
max_epochs=250,
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
|
@ -2,22 +2,10 @@
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence,
|
||||
PruneLoserPrototypes)
|
||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||
from .cbc import CBC, ImageCBC
|
||||
from .glvq import (
|
||||
GLVQ,
|
||||
GLVQ1,
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
)
|
||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
||||
from .unsupervised import KNN, NeuralGas
|
||||
|
@ -4,22 +4,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
|
||||
"""Run early stopping at the end of training loop.
|
||||
|
||||
See:
|
||||
https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
|
||||
|
||||
"""
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
# override this to disable early stopping at the end of val loop
|
||||
pass
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
# instead, do it at the end of training loop
|
||||
self._run_early_stopping_check(trainer, pl_module)
|
||||
|
||||
|
||||
class PruneLoserPrototypes(pl.Callback):
|
||||
def __init__(self,
|
||||
threshold=0.01,
|
||||
|
Loading…
Reference in New Issue
Block a user