Add example for dynamic components in callbacks

This commit is contained in:
Alexander Engelsberger 2021-05-31 11:39:24 +02:00
parent db064b5af1
commit 2a218c0ede
2 changed files with 28 additions and 21 deletions

View File

@ -2,10 +2,24 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback
class PrototypeScheduler(Callback):
def __init__(self, train_ds, freq=20):
self.train_ds = train_ds
self.freq = freq
def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) % self.freq == 0:
pl_module.increase_prototypes(
pt.components.SMI(self.train_ds),
distribution=[1, 1, 1],
)
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
@ -33,24 +47,17 @@ if __name__ == "__main__":
prototype_initializer=pt.components.SMI(train_ds),
)
for _ in range(5):
# Callbacks
vis = pt.models.VisGLVQ2D(train_ds)
# Callbacks
vis = pt.models.VisGLVQ2D(train_ds)
proto_scheduler = PrototypeScheduler(train_ds, 10)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=20,
callbacks=[vis],
terminate_on_nan=True,
weights_summary=None,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(args,
max_epochs=100,
callbacks=[vis, proto_scheduler],
terminate_on_nan=True,
weights_summary=None,
accelerator='ddp')
# Training loop
trainer.fit(model, train_loader)
# Increase prototypes
model.increase_prototypes(
pt.components.SMI(train_ds),
distribution=[1, 1, 1],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -137,7 +137,7 @@ class GLVQ(AbstractPrototypeModel):
def increase_prototypes(self, initializer, distribution):
self.proto_layer.increase_components(initializer, distribution)
#self.trainer.accelerated_backend.setup_optimizers(self)
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
def __repr__(self):
super_repr = super().__repr__()