[BUGFIX] Pruning example works on GPU now

This commit is contained in:
Alexander Engelsberger 2021-06-03 14:35:24 +02:00
parent 1b09b1d57b
commit e209bf73d5
2 changed files with 9 additions and 11 deletions

View File

@ -17,9 +17,6 @@ class PruneLoserPrototypes(pl.Callback):
self.frequency = frequency
self.verbose = verbose
def on_epoch_start(self, trainer, pl_module):
pl_module.initialize_prototype_win_ratios()
def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None

View File

@ -5,12 +5,9 @@ import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (
euclidean_distance,
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.distances import (euclidean_distance,
lomega_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer
@ -97,8 +94,12 @@ class GLVQ(AbstractPrototypeModel):
logger=True)
def initialize_prototype_win_ratios(self):
self.prototype_win_ratios = torch.zeros(self.num_prototypes,
device=self.device)
self.register_buffer(
"prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device))
def on_epoch_start(self):
self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances):
batch_size = len(distances)