[BUGFIX] Pruning example works on GPU now
This commit is contained in:
		@@ -17,9 +17,6 @@ class PruneLoserPrototypes(pl.Callback):
 | 
				
			|||||||
        self.frequency = frequency
 | 
					        self.frequency = frequency
 | 
				
			||||||
        self.verbose = verbose
 | 
					        self.verbose = verbose
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_epoch_start(self, trainer, pl_module):
 | 
					 | 
				
			||||||
        pl_module.initialize_prototype_win_ratios()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if (trainer.current_epoch + 1) < self.idle_epochs:
 | 
					        if (trainer.current_epoch + 1) < self.idle_epochs:
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,12 +5,9 @@ import torchmetrics
 | 
				
			|||||||
from prototorch.components import LabeledComponents
 | 
					from prototorch.components import LabeledComponents
 | 
				
			||||||
from prototorch.functions.activations import get_activation
 | 
					from prototorch.functions.activations import get_activation
 | 
				
			||||||
from prototorch.functions.competitions import wtac
 | 
					from prototorch.functions.competitions import wtac
 | 
				
			||||||
from prototorch.functions.distances import (
 | 
					from prototorch.functions.distances import (euclidean_distance,
 | 
				
			||||||
    euclidean_distance,
 | 
					                                            lomega_distance, omega_distance,
 | 
				
			||||||
    lomega_distance,
 | 
					                                            squared_euclidean_distance)
 | 
				
			||||||
    omega_distance,
 | 
					 | 
				
			||||||
    squared_euclidean_distance,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from prototorch.functions.helper import get_flat
 | 
					from prototorch.functions.helper import get_flat
 | 
				
			||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
					from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
				
			||||||
from prototorch.modules import LambdaLayer
 | 
					from prototorch.modules import LambdaLayer
 | 
				
			||||||
@@ -97,8 +94,12 @@ class GLVQ(AbstractPrototypeModel):
 | 
				
			|||||||
                 logger=True)
 | 
					                 logger=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def initialize_prototype_win_ratios(self):
 | 
					    def initialize_prototype_win_ratios(self):
 | 
				
			||||||
        self.prototype_win_ratios = torch.zeros(self.num_prototypes,
 | 
					        self.register_buffer(
 | 
				
			||||||
                                                device=self.device)
 | 
					            "prototype_win_ratios",
 | 
				
			||||||
 | 
					            torch.zeros(self.num_prototypes, device=self.device))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_epoch_start(self):
 | 
				
			||||||
 | 
					        self.initialize_prototype_win_ratios()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def log_prototype_win_ratios(self, distances):
 | 
					    def log_prototype_win_ratios(self, distances):
 | 
				
			||||||
        batch_size = len(distances)
 | 
					        batch_size = len(distances)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user