[REFACTOR] Update examples/dynamic_pruning.py
This commit is contained in:
parent
8851d1bbc9
commit
bdacc83185
@ -5,33 +5,6 @@ import argparse
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning.callbacks import Callback
|
|
||||||
|
|
||||||
|
|
||||||
class PrototypePruning(Callback):
|
|
||||||
def __init__(self, threshold=0.01, prune_after=10, verbose=False):
|
|
||||||
self.threshold = threshold
|
|
||||||
self.prune_after = prune_after
|
|
||||||
self.verbose = verbose
|
|
||||||
|
|
||||||
def on_epoch_start(self, trainer, pl_module):
|
|
||||||
pl_module.initialize_prototype_win_ratios()
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
|
||||||
if (trainer.current_epoch + 1) > self.prune_after:
|
|
||||||
ratios = pl_module.prototype_win_ratios.mean(dim=0)
|
|
||||||
to_prune = torch.arange(len(ratios))[ratios < self.threshold]
|
|
||||||
if len(to_prune) > 0:
|
|
||||||
if self.verbose:
|
|
||||||
print(f"\nPrototype win ratios: {ratios}")
|
|
||||||
print(f"Pruning prototypes at indices: {to_prune}")
|
|
||||||
cur_num_protos = pl_module.num_prototypes
|
|
||||||
pl_module.remove_prototypes(indices=to_prune)
|
|
||||||
new_num_protos = pl_module.num_prototypes
|
|
||||||
if self.verbose:
|
|
||||||
print(f"`num_prototypes` reduced from {cur_num_protos} "
|
|
||||||
f"to {new_num_protos}.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
@ -71,9 +44,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(train_ds)
|
vis = pt.models.VisGLVQ2D(train_ds)
|
||||||
pruning = PrototypePruning(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
threshold=0.01, # prune prototype if it wins less than 1%
|
threshold=0.01, # prune prototype if it wins less than 1%
|
||||||
prune_after=50,
|
prune_after_epochs=30, # pruning too early may cause problems
|
||||||
|
prune_quota_per_epoch=1, # prune at most 1 prototype per epoch
|
||||||
|
frequency=5, # prune every fifth epoch
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user