fix: update hparams.distribution as it changes during training

This commit is contained in:
Jensun Ravichandran 2022-02-02 21:53:03 +01:00
parent 15e7232747
commit dd696ea1e0
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921

View File

@ -3,12 +3,11 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.initializers import ZerosCompInitializer
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.initializers import LabelsInitializer, ZerosCompInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
@ -76,10 +75,12 @@ class PrototypeModel(ProtoTorchBolt):
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.hparams.distribution = self.proto_layer.distribution
self.reconfigure_optimizers()
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.hparams.distribution = self.proto_layer.distribution
self.reconfigure_optimizers()