fix: update hparams.distribution as it changes during training
				
					
				
			This commit is contained in:
		| @@ -3,12 +3,11 @@ | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| import torchmetrics | ||||
| from prototorch.core.initializers import ZerosCompInitializer | ||||
|  | ||||
| from ..core.competitions import WTAC | ||||
| from ..core.components import Components, LabeledComponents | ||||
| from ..core.distances import euclidean_distance | ||||
| from ..core.initializers import LabelsInitializer | ||||
| from ..core.initializers import LabelsInitializer, ZerosCompInitializer | ||||
| from ..core.pooling import stratified_min_pooling | ||||
| from ..nn.wrappers import LambdaLayer | ||||
|  | ||||
| @@ -76,10 +75,12 @@ class PrototypeModel(ProtoTorchBolt): | ||||
|  | ||||
|     def add_prototypes(self, *args, **kwargs): | ||||
|         self.proto_layer.add_components(*args, **kwargs) | ||||
|         self.hparams.distribution = self.proto_layer.distribution | ||||
|         self.reconfigure_optimizers() | ||||
|  | ||||
|     def remove_prototypes(self, indices): | ||||
|         self.proto_layer.remove_components(indices) | ||||
|         self.hparams.distribution = self.proto_layer.distribution | ||||
|         self.reconfigure_optimizers() | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user