feat: add early-stopping and pruning to examples/warm_starting.py
This commit is contained in:
parent
09e3ef1d0e
commit
0f9f24e36a
@ -37,7 +37,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Setup trainer for GNG
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=200,
|
||||
max_epochs=100,
|
||||
callbacks=[es],
|
||||
weights_summary=None,
|
||||
)
|
||||
@ -71,11 +71,30 @@ if __name__ == "__main__":
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.02,
|
||||
idle_epochs=2,
|
||||
prune_quota_per_epoch=5,
|
||||
frequency=1,
|
||||
verbose=True,
|
||||
)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_loss",
|
||||
min_delta=0.001,
|
||||
patience=10,
|
||||
mode="min",
|
||||
verbose=True,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
es,
|
||||
],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user