fix: update hparams.distribution
as it changes during training
This commit is contained in:
parent
15e7232747
commit
dd696ea1e0
@ -3,12 +3,11 @@
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.core.initializers import ZerosCompInitializer
|
|
||||||
|
|
||||||
from ..core.competitions import WTAC
|
from ..core.competitions import WTAC
|
||||||
from ..core.components import Components, LabeledComponents
|
from ..core.components import Components, LabeledComponents
|
||||||
from ..core.distances import euclidean_distance
|
from ..core.distances import euclidean_distance
|
||||||
from ..core.initializers import LabelsInitializer
|
from ..core.initializers import LabelsInitializer, ZerosCompInitializer
|
||||||
from ..core.pooling import stratified_min_pooling
|
from ..core.pooling import stratified_min_pooling
|
||||||
from ..nn.wrappers import LambdaLayer
|
from ..nn.wrappers import LambdaLayer
|
||||||
|
|
||||||
@ -76,10 +75,12 @@ class PrototypeModel(ProtoTorchBolt):
|
|||||||
|
|
||||||
def add_prototypes(self, *args, **kwargs):
|
def add_prototypes(self, *args, **kwargs):
|
||||||
self.proto_layer.add_components(*args, **kwargs)
|
self.proto_layer.add_components(*args, **kwargs)
|
||||||
|
self.hparams.distribution = self.proto_layer.distribution
|
||||||
self.reconfigure_optimizers()
|
self.reconfigure_optimizers()
|
||||||
|
|
||||||
def remove_prototypes(self, indices):
|
def remove_prototypes(self, indices):
|
||||||
self.proto_layer.remove_components(indices)
|
self.proto_layer.remove_components(indices)
|
||||||
|
self.hparams.distribution = self.proto_layer.distribution
|
||||||
self.reconfigure_optimizers()
|
self.reconfigure_optimizers()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user