[BUGFIX] Early stopping example works now

This commit is contained in:
Alexander Engelsberger
2021-06-03 13:38:16 +02:00
parent 64250d0938
commit 3b02d99ebe
3 changed files with 12 additions and 40 deletions

View File

@@ -2,22 +2,10 @@
from importlib.metadata import PackageNotFoundError, version
from .callbacks import (EarlyStopWithoutVal, PrototypeConvergence,
PruneLoserPrototypes)
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import KNN, NeuralGas

View File

@@ -4,22 +4,6 @@ import pytorch_lightning as pl
import torch
class EarlyStopWithoutVal(pl.callbacks.EarlyStopping):
"""Run early stopping at the end of training loop.
See:
https://pytorch-lightning.readthedocs.io/en/latest/common/early_stopping.html
"""
def on_validation_end(self, trainer, pl_module):
# override this to disable early stopping at the end of val loop
pass
def on_train_end(self, trainer, pl_module):
# instead, do it at the end of training loop
self._run_early_stopping_check(trainer, pl_module)
class PruneLoserPrototypes(pl.Callback):
def __init__(self,
threshold=0.01,