diff --git a/examples/dynamic_pruning.py b/examples/dynamic_pruning.py index ee282b4..e9f8946 100644 --- a/examples/dynamic_pruning.py +++ b/examples/dynamic_pruning.py @@ -5,33 +5,6 @@ import argparse import prototorch as pt import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks import Callback - - -class PrototypePruning(Callback): - def __init__(self, threshold=0.01, prune_after=10, verbose=False): - self.threshold = threshold - self.prune_after = prune_after - self.verbose = verbose - - def on_epoch_start(self, trainer, pl_module): - pl_module.initialize_prototype_win_ratios() - - def on_epoch_end(self, trainer, pl_module): - if (trainer.current_epoch + 1) > self.prune_after: - ratios = pl_module.prototype_win_ratios.mean(dim=0) - to_prune = torch.arange(len(ratios))[ratios < self.threshold] - if len(to_prune) > 0: - if self.verbose: - print(f"\nPrototype win ratios: {ratios}") - print(f"Pruning prototypes at indices: {to_prune}") - cur_num_protos = pl_module.num_prototypes - pl_module.remove_prototypes(indices=to_prune) - new_num_protos = pl_module.num_prototypes - if self.verbose: - print(f"`num_prototypes` reduced from {cur_num_protos} " - f"to {new_num_protos}.") - if __name__ == "__main__": # Command-line arguments @@ -71,9 +44,11 @@ if __name__ == "__main__": # Callbacks vis = pt.models.VisGLVQ2D(train_ds) - pruning = PrototypePruning( + pruning = pt.models.PruneLoserPrototypes( threshold=0.01, # prune prototype if it wins less than 1% - prune_after=50, + prune_after_epochs=30, # pruning too early may cause problems + prune_quota_per_epoch=1, # prune at most 1 prototype per epoch + frequency=5, # prune every fifth epoch verbose=True, )