[BUGFIX] Pruning example works on GPU now
This commit is contained in:
parent
1b09b1d57b
commit
e209bf73d5
@ -17,9 +17,6 @@ class PruneLoserPrototypes(pl.Callback):
|
||||
self.frequency = frequency
|
||||
self.verbose = verbose
|
||||
|
||||
def on_epoch_start(self, trainer, pl_module):
|
||||
pl_module.initialize_prototype_win_ratios()
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||
return None
|
||||
|
@ -5,12 +5,9 @@ import torchmetrics
|
||||
from prototorch.components import LabeledComponents
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (
|
||||
euclidean_distance,
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.functions.distances import (euclidean_distance,
|
||||
lomega_distance, omega_distance,
|
||||
squared_euclidean_distance)
|
||||
from prototorch.functions.helper import get_flat
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from prototorch.modules import LambdaLayer
|
||||
@ -97,8 +94,12 @@ class GLVQ(AbstractPrototypeModel):
|
||||
logger=True)
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.prototype_win_ratios = torch.zeros(self.num_prototypes,
|
||||
device=self.device)
|
||||
self.register_buffer(
|
||||
"prototype_win_ratios",
|
||||
torch.zeros(self.num_prototypes, device=self.device))
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.initialize_prototype_win_ratios()
|
||||
|
||||
def log_prototype_win_ratios(self, distances):
|
||||
batch_size = len(distances)
|
||||
|
Loading…
Reference in New Issue
Block a user