[REFACTOR] Update examples/dynamic_pruning.py

This commit is contained in:
Jensun Ravichandran 2021-06-02 03:53:21 +02:00
parent 8851d1bbc9
commit bdacc83185

View File

@ -5,33 +5,6 @@ import argparse
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from pytorch_lightning.callbacks import Callback
class PrototypePruning(Callback):
def __init__(self, threshold=0.01, prune_after=10, verbose=False):
self.threshold = threshold
self.prune_after = prune_after
self.verbose = verbose
def on_epoch_start(self, trainer, pl_module):
pl_module.initialize_prototype_win_ratios()
def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) > self.prune_after:
ratios = pl_module.prototype_win_ratios.mean(dim=0)
to_prune = torch.arange(len(ratios))[ratios < self.threshold]
if len(to_prune) > 0:
if self.verbose:
print(f"\nPrototype win ratios: {ratios}")
print(f"Pruning prototypes at indices: {to_prune}")
cur_num_protos = pl_module.num_prototypes
pl_module.remove_prototypes(indices=to_prune)
new_num_protos = pl_module.num_prototypes
if self.verbose:
print(f"`num_prototypes` reduced from {cur_num_protos} "
f"to {new_num_protos}.")
if __name__ == "__main__": if __name__ == "__main__":
# Command-line arguments # Command-line arguments
@ -71,9 +44,11 @@ if __name__ == "__main__":
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(train_ds) vis = pt.models.VisGLVQ2D(train_ds)
pruning = PrototypePruning( pruning = pt.models.PruneLoserPrototypes(
threshold=0.01, # prune prototype if it wins less than 1% threshold=0.01, # prune prototype if it wins less than 1%
prune_after=50, prune_after_epochs=30, # pruning too early may cause problems
prune_quota_per_epoch=1, # prune at most 1 prototype per epoch
frequency=5, # prune every fifth epoch
verbose=True, verbose=True,
) )