Add example for dynamic components in callbacks

This commit is contained in:
Alexander Engelsberger 2021-05-31 11:39:24 +02:00
parent db064b5af1
commit 2a218c0ede
2 changed files with 28 additions and 21 deletions

View File

@ -2,10 +2,24 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback
class PrototypeScheduler(Callback):
    """Lightning callback that periodically grows the model's prototype set.

    Every ``freq`` epochs the wrapped module's ``increase_prototypes`` is
    called with fresh SMI-initialized components drawn from ``train_ds``.

    Args:
        train_ds: Training dataset used to initialize the new prototypes.
        freq: Number of epochs between two prototype insertions.
        distribution: Per-class counts of prototypes to add on each trigger.
            Defaults to ``[1, 1, 1]`` (one per class, matching the original
            three-class example) — avoids a mutable default argument.
    """

    def __init__(self, train_ds, freq=20, distribution=None):
        self.train_ds = train_ds
        self.freq = freq
        self.distribution = [1, 1, 1] if distribution is None else distribution

    def on_epoch_end(self, trainer, pl_module):
        # Epochs are 0-indexed, hence the +1: fires on epoch freq-1, 2*freq-1, ...
        if (trainer.current_epoch + 1) % self.freq == 0:
            pl_module.increase_prototypes(
                pt.components.SMI(self.train_ds),
                distribution=self.distribution,
            )
import prototorch as pt
if __name__ == "__main__":
    # Command-line arguments
@ -33,24 +47,17 @@ if __name__ == "__main__":
        prototype_initializer=pt.components.SMI(train_ds),
    )

    # Callbacks
    vis = pt.models.VisGLVQ2D(train_ds)
    proto_scheduler = PrototypeScheduler(train_ds, 10)

    # Setup trainer
    trainer = pl.Trainer.from_argparse_args(args,
                                            max_epochs=100,
                                            callbacks=[vis, proto_scheduler],
                                            terminate_on_nan=True,
                                            weights_summary=None,
                                            accelerator='ddp')

    # Training loop
    trainer.fit(model, train_loader)

View File

@ -137,7 +137,7 @@ class GLVQ(AbstractPrototypeModel):
    def increase_prototypes(self, initializer, distribution):
        self.proto_layer.increase_components(initializer, distribution)
        self.trainer.accelerator_backend.setup_optimizers(self.trainer)

    def __repr__(self):
        super_repr = super().__repr__()