feat: add early-stopping and pruning to examples/warm_starting.py
This commit is contained in:
parent
09e3ef1d0e
commit
0f9f24e36a
@ -37,7 +37,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Setup trainer for GNG
|
# Setup trainer for GNG
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
max_epochs=200,
|
max_epochs=100,
|
||||||
callbacks=[es],
|
callbacks=[es],
|
||||||
weights_summary=None,
|
weights_summary=None,
|
||||||
)
|
)
|
||||||
@ -71,11 +71,30 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
|
threshold=0.02,
|
||||||
|
idle_epochs=2,
|
||||||
|
prune_quota_per_epoch=5,
|
||||||
|
frequency=1,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_loss",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=10,
|
||||||
|
mode="min",
|
||||||
|
verbose=True,
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
pruning,
|
||||||
|
es,
|
||||||
|
],
|
||||||
weights_summary="full",
|
weights_summary="full",
|
||||||
accelerator="ddp",
|
accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user