[REFACTOR] Update examples/dynamic_pruning.py
This commit is contained in:
parent
8851d1bbc9
commit
bdacc83185
@ -5,33 +5,6 @@ import argparse
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
|
||||
|
||||
class PrototypePruning(Callback):
|
||||
def __init__(self, threshold=0.01, prune_after=10, verbose=False):
|
||||
self.threshold = threshold
|
||||
self.prune_after = prune_after
|
||||
self.verbose = verbose
|
||||
|
||||
def on_epoch_start(self, trainer, pl_module):
|
||||
pl_module.initialize_prototype_win_ratios()
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if (trainer.current_epoch + 1) > self.prune_after:
|
||||
ratios = pl_module.prototype_win_ratios.mean(dim=0)
|
||||
to_prune = torch.arange(len(ratios))[ratios < self.threshold]
|
||||
if len(to_prune) > 0:
|
||||
if self.verbose:
|
||||
print(f"\nPrototype win ratios: {ratios}")
|
||||
print(f"Pruning prototypes at indices: {to_prune}")
|
||||
cur_num_protos = pl_module.num_prototypes
|
||||
pl_module.remove_prototypes(indices=to_prune)
|
||||
new_num_protos = pl_module.num_prototypes
|
||||
if self.verbose:
|
||||
print(f"`num_prototypes` reduced from {cur_num_protos} "
|
||||
f"to {new_num_protos}.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
@ -71,9 +44,11 @@ if __name__ == "__main__":
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(train_ds)
|
||||
pruning = PrototypePruning(
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01, # prune prototype if it wins less than 1%
|
||||
prune_after=50,
|
||||
prune_after_epochs=30, # pruning too early may cause problems
|
||||
prune_quota_per_epoch=1, # prune at most 1 prototype per epoch
|
||||
frequency=5, # prune every fifth epoch
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user