[BUGFIX] Pruning example works on GPU now
This commit is contained in:
parent
1b09b1d57b
commit
e209bf73d5
@ -17,9 +17,6 @@ class PruneLoserPrototypes(pl.Callback):
|
|||||||
self.frequency = frequency
|
self.frequency = frequency
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
def on_epoch_start(self, trainer, pl_module):
|
|
||||||
pl_module.initialize_prototype_win_ratios()
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
if (trainer.current_epoch + 1) < self.idle_epochs:
|
if (trainer.current_epoch + 1) < self.idle_epochs:
|
||||||
return None
|
return None
|
||||||
|
@ -5,12 +5,9 @@ import torchmetrics
|
|||||||
from prototorch.components import LabeledComponents
|
from prototorch.components import LabeledComponents
|
||||||
from prototorch.functions.activations import get_activation
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (
|
from prototorch.functions.distances import (euclidean_distance,
|
||||||
euclidean_distance,
|
lomega_distance, omega_distance,
|
||||||
lomega_distance,
|
squared_euclidean_distance)
|
||||||
omega_distance,
|
|
||||||
squared_euclidean_distance,
|
|
||||||
)
|
|
||||||
from prototorch.functions.helper import get_flat
|
from prototorch.functions.helper import get_flat
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
from prototorch.modules import LambdaLayer
|
from prototorch.modules import LambdaLayer
|
||||||
@ -97,8 +94,12 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
logger=True)
|
logger=True)
|
||||||
|
|
||||||
def initialize_prototype_win_ratios(self):
|
def initialize_prototype_win_ratios(self):
|
||||||
self.prototype_win_ratios = torch.zeros(self.num_prototypes,
|
self.register_buffer(
|
||||||
device=self.device)
|
"prototype_win_ratios",
|
||||||
|
torch.zeros(self.num_prototypes, device=self.device))
|
||||||
|
|
||||||
|
def on_epoch_start(self):
|
||||||
|
self.initialize_prototype_win_ratios()
|
||||||
|
|
||||||
def log_prototype_win_ratios(self, distances):
|
def log_prototype_win_ratios(self, distances):
|
||||||
batch_size = len(distances)
|
batch_size = len(distances)
|
||||||
|
Loading…
Reference in New Issue
Block a user