[BUGFIX] Early stopping example works now
This commit is contained in:
@@ -2,22 +2,10 @@
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence,
|
||||
PruneLoserPrototypes)
|
||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||
from .cbc import CBC, ImageCBC
|
||||
from .glvq import (
|
||||
GLVQ,
|
||||
GLVQ1,
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
)
|
||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
||||
from .unsupervised import KNN, NeuralGas
|
||||
|
@@ -4,22 +4,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
|
||||
"""Run early stopping at the end of training loop.
|
||||
|
||||
See:
|
||||
https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
|
||||
|
||||
"""
|
||||
def on_validation_end(self, trainer, pl_module):
|
||||
# override this to disable early stopping at the end of val loop
|
||||
pass
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
# instead, do it at the end of training loop
|
||||
self._run_early_stopping_check(trainer, pl_module)
|
||||
|
||||
|
||||
class PruneLoserPrototypes(pl.Callback):
|
||||
def __init__(self,
|
||||
threshold=0.01,
|
||||
|
Reference in New Issue
Block a user