fix: update hparams.distribution as it changes during training

Author: Jensun Ravichandran, 2022-02-02 21:53:03 +01:00
parent 15e7232747
commit dd696ea1e0


@@ -3,12 +3,11 @@
 import pytorch_lightning as pl
 import torch
 import torchmetrics
-from prototorch.core.initializers import ZerosCompInitializer
 from ..core.competitions import WTAC
 from ..core.components import Components, LabeledComponents
 from ..core.distances import euclidean_distance
-from ..core.initializers import LabelsInitializer
+from ..core.initializers import LabelsInitializer, ZerosCompInitializer
 from ..core.pooling import stratified_min_pooling
 from ..nn.wrappers import LambdaLayer
@@ -76,10 +75,12 @@ class PrototypeModel(ProtoTorchBolt):
     def add_prototypes(self, *args, **kwargs):
         self.proto_layer.add_components(*args, **kwargs)
+        self.hparams.distribution = self.proto_layer.distribution
         self.reconfigure_optimizers()
 
     def remove_prototypes(self, indices):
         self.proto_layer.remove_components(indices)
+        self.hparams.distribution = self.proto_layer.distribution
         self.reconfigure_optimizers()