[BUGFIX] examples/glvq_spiral.py
works again
This commit is contained in:
@@ -16,7 +16,7 @@ class PruneLoserPrototypes(pl.Callback):
|
||||
prune_quota_per_epoch=-1,
|
||||
frequency=1,
|
||||
replace=False,
|
||||
initializer=None,
|
||||
prototypes_initializer=None,
|
||||
verbose=False):
|
||||
self.threshold = threshold # minimum win ratio
|
||||
self.idle_epochs = idle_epochs # epochs to wait before pruning
|
||||
@@ -24,7 +24,7 @@ class PruneLoserPrototypes(pl.Callback):
|
||||
self.frequency = frequency
|
||||
self.replace = replace
|
||||
self.verbose = verbose
|
||||
self.initializer = initializer
|
||||
self.prototypes_initializer = prototypes_initializer
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||
@@ -55,8 +55,9 @@ class PruneLoserPrototypes(pl.Callback):
|
||||
if self.verbose:
|
||||
print(f"Re-adding pruned prototypes...")
|
||||
print(f"{distribution=}")
|
||||
pl_module.add_prototypes(distribution=distribution,
|
||||
initializer=self.initializer)
|
||||
pl_module.add_prototypes(
|
||||
distribution=distribution,
|
||||
components_initializer=self.prototypes_initializer)
|
||||
new_num_protos = pl_module.num_prototypes
|
||||
if self.verbose:
|
||||
print(f"`num_prototypes` changed from {cur_num_protos} "
|
||||
|
Reference in New Issue
Block a user