fix: update hparams.distribution as it changes during training
				
					
				
			This commit is contained in:
		@@ -3,12 +3,11 @@
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
from prototorch.core.initializers import ZerosCompInitializer
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.competitions import WTAC
 | 
					from ..core.competitions import WTAC
 | 
				
			||||||
from ..core.components import Components, LabeledComponents
 | 
					from ..core.components import Components, LabeledComponents
 | 
				
			||||||
from ..core.distances import euclidean_distance
 | 
					from ..core.distances import euclidean_distance
 | 
				
			||||||
from ..core.initializers import LabelsInitializer
 | 
					from ..core.initializers import LabelsInitializer, ZerosCompInitializer
 | 
				
			||||||
from ..core.pooling import stratified_min_pooling
 | 
					from ..core.pooling import stratified_min_pooling
 | 
				
			||||||
from ..nn.wrappers import LambdaLayer
 | 
					from ..nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -76,10 +75,12 @@ class PrototypeModel(ProtoTorchBolt):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def add_prototypes(self, *args, **kwargs):
 | 
					    def add_prototypes(self, *args, **kwargs):
 | 
				
			||||||
        self.proto_layer.add_components(*args, **kwargs)
 | 
					        self.proto_layer.add_components(*args, **kwargs)
 | 
				
			||||||
 | 
					        self.hparams.distribution = self.proto_layer.distribution
 | 
				
			||||||
        self.reconfigure_optimizers()
 | 
					        self.reconfigure_optimizers()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def remove_prototypes(self, indices):
 | 
					    def remove_prototypes(self, indices):
 | 
				
			||||||
        self.proto_layer.remove_components(indices)
 | 
					        self.proto_layer.remove_components(indices)
 | 
				
			||||||
 | 
					        self.hparams.distribution = self.proto_layer.distribution
 | 
				
			||||||
        self.reconfigure_optimizers()
 | 
					        self.reconfigure_optimizers()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user