[BUGFIX] Early stopping example works now
This commit is contained in:
parent
64250d0938
commit
3b02d99ebe
@ -46,23 +46,23 @@ if __name__ == "__main__":
|
|||||||
vis = pt.models.VisGLVQ2D(train_ds)
|
vis = pt.models.VisGLVQ2D(train_ds)
|
||||||
pruning = pt.models.PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.01, # prune prototype if it wins less than 1%
|
threshold=0.01, # prune prototype if it wins less than 1%
|
||||||
idle_epochs=30, # pruning too early may cause problems
|
idle_epochs=10, # pruning too early may cause problems
|
||||||
prune_quota_per_epoch=1, # prune at most 1 prototype per epoch
|
prune_quota_per_epoch=5, # prune at most 5 prototypes per epoch
|
||||||
frequency=5, # prune every fifth epoch
|
frequency=2, # prune every second epoch
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
es = pt.models.EarlyStopWithoutVal(
|
|
||||||
monitor="loss",
|
es = pl.callbacks.EarlyStopping(
|
||||||
min_delta=0.1,
|
monitor="train_loss",
|
||||||
patience=3,
|
min_delta=0.001,
|
||||||
|
patience=15,
|
||||||
mode="min",
|
mode="min",
|
||||||
verbose=True,
|
check_on_train_epoch_end=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
max_epochs=250,
|
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
pruning,
|
pruning,
|
||||||
|
@ -2,22 +2,10 @@
|
|||||||
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence,
|
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||||
PruneLoserPrototypes)
|
|
||||||
from .cbc import CBC, ImageCBC
|
from .cbc import CBC, ImageCBC
|
||||||
from .glvq import (
|
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||||
GLVQ,
|
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||||
GLVQ1,
|
|
||||||
GLVQ21,
|
|
||||||
GMLVQ,
|
|
||||||
GRLVQ,
|
|
||||||
LGMLVQ,
|
|
||||||
LVQMLN,
|
|
||||||
ImageGLVQ,
|
|
||||||
ImageGMLVQ,
|
|
||||||
SiameseGLVQ,
|
|
||||||
SiameseGMLVQ,
|
|
||||||
)
|
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
||||||
from .unsupervised import KNN, NeuralGas
|
from .unsupervised import KNN, NeuralGas
|
||||||
|
@ -4,22 +4,6 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
|
|
||||||
"""Run early stopping at the end of training loop.
|
|
||||||
|
|
||||||
See:
|
|
||||||
https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
|
|
||||||
|
|
||||||
"""
|
|
||||||
def on_validation_end(self, trainer, pl_module):
|
|
||||||
# override this to disable early stopping at the end of val loop
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_train_end(self, trainer, pl_module):
|
|
||||||
# instead, do it at the end of training loop
|
|
||||||
self._run_early_stopping_check(trainer, pl_module)
|
|
||||||
|
|
||||||
|
|
||||||
class PruneLoserPrototypes(pl.Callback):
|
class PruneLoserPrototypes(pl.Callback):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
threshold=0.01,
|
threshold=0.01,
|
||||||
|
Loading…
Reference in New Issue
Block a user