fix: update hparams.distribution
as it changes during training
parent 15e7232747
commit dd696ea1e0
@@ -3,12 +3,11 @@
 import pytorch_lightning as pl
 import torch
 import torchmetrics
-from prototorch.core.initializers import ZerosCompInitializer

 from ..core.competitions import WTAC
 from ..core.components import Components, LabeledComponents
 from ..core.distances import euclidean_distance
-from ..core.initializers import LabelsInitializer
+from ..core.initializers import LabelsInitializer, ZerosCompInitializer
 from ..core.pooling import stratified_min_pooling
 from ..nn.wrappers import LambdaLayer

@@ -76,10 +75,12 @@ class PrototypeModel(ProtoTorchBolt):

     def add_prototypes(self, *args, **kwargs):
         self.proto_layer.add_components(*args, **kwargs)
+        self.hparams.distribution = self.proto_layer.distribution
         self.reconfigure_optimizers()

     def remove_prototypes(self, indices):
         self.proto_layer.remove_components(indices)
+        self.hparams.distribution = self.proto_layer.distribution
         self.reconfigure_optimizers()

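Usage sketch (not part of the commit): the change keeps the cached hparams.distribution in sync with the component layer whenever prototypes are added or removed during training, so that reconfigure_optimizers() and later checkpointing see the current prototype counts. The toy classes below, ToyProtoLayer and ToyModel, are hypothetical stand-ins illustrating the pattern; they are not prototorch APIs.

# Minimal, self-contained sketch of the pattern enforced above: keep a
# cached copy of the layer's distribution in sync whenever components
# are added or removed.  ToyProtoLayer and ToyModel are hypothetical
# stand-ins, not prototorch classes.

class ToyProtoLayer:
    def __init__(self, distribution):
        # distribution maps class label -> number of prototypes
        self.distribution = dict(distribution)

    def add_components(self, label, num=1):
        self.distribution[label] = self.distribution.get(label, 0) + num

    def remove_components(self, label, num=1):
        self.distribution[label] = max(self.distribution.get(label, 0) - num, 0)


class ToyModel:
    def __init__(self, distribution):
        self.hparams = {"distribution": dict(distribution)}
        self.proto_layer = ToyProtoLayer(distribution)

    def add_prototypes(self, *args, **kwargs):
        self.proto_layer.add_components(*args, **kwargs)
        # Without this refresh, hparams would still describe the old layout.
        self.hparams["distribution"] = self.proto_layer.distribution

    def remove_prototypes(self, *args, **kwargs):
        self.proto_layer.remove_components(*args, **kwargs)
        self.hparams["distribution"] = self.proto_layer.distribution


model = ToyModel({0: 2, 1: 2})
model.add_prototypes(0, num=3)
assert model.hparams["distribution"] == {0: 5, 1: 2}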