[BUGFIX] Pruning example works on GPU now

This commit is contained in:
Alexander Engelsberger 2021-06-03 14:35:24 +02:00
parent 1b09b1d57b
commit e209bf73d5
2 changed files with 9 additions and 11 deletions

View File

@ -17,9 +17,6 @@ class PruneLoserPrototypes(pl.Callback):
self.frequency = frequency self.frequency = frequency
self.verbose = verbose self.verbose = verbose
def on_epoch_start(self, trainer, pl_module):
pl_module.initialize_prototype_win_ratios()
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs: if (trainer.current_epoch + 1) < self.idle_epochs:
return None return None

View File

@ -5,12 +5,9 @@ import torchmetrics
from prototorch.components import LabeledComponents from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import ( from prototorch.functions.distances import (euclidean_distance,
euclidean_distance, lomega_distance, omega_distance,
lomega_distance, squared_euclidean_distance)
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.helper import get_flat from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer from prototorch.modules import LambdaLayer
@ -97,8 +94,12 @@ class GLVQ(AbstractPrototypeModel):
logger=True) logger=True)
def initialize_prototype_win_ratios(self): def initialize_prototype_win_ratios(self):
self.prototype_win_ratios = torch.zeros(self.num_prototypes, self.register_buffer(
device=self.device) "prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device))
def on_epoch_start(self):
self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances): def log_prototype_win_ratios(self, distances):
batch_size = len(distances) batch_size = len(distances)