fix: dont save prototype initializer in yarch checkpoint
This commit is contained in:
parent
be7d7f43bd
commit
60d2a1d2c9
@ -6,13 +6,7 @@ Network architecture for Component based Learning.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import Any, Callable
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Set,
|
|
||||||
Type,
|
|
||||||
)
|
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
@ -34,12 +28,15 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
|
|
||||||
def __init__(self, hparams) -> None:
|
def __init__(self, hparams) -> None:
|
||||||
if type(hparams) is dict:
|
if type(hparams) is dict:
|
||||||
|
self.save_hyperparameters(hparams)
|
||||||
|
# TODO: => Move into Component Child
|
||||||
|
del hparams["initialized_proto_shape"]
|
||||||
hparams = self.HyperParameters(**hparams)
|
hparams = self.HyperParameters(**hparams)
|
||||||
|
else:
|
||||||
|
self.save_hyperparameters(hparams.__dict__)
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.save_hyperparameters(hparams.__dict__)
|
|
||||||
|
|
||||||
# Common Steps
|
# Common Steps
|
||||||
self.init_components(hparams)
|
self.init_components(hparams)
|
||||||
self.init_latent(hparams)
|
self.init_latent(hparams)
|
||||||
@ -220,6 +217,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
self.update_metrics_epoch()
|
self.update_metrics_epoch()
|
||||||
|
|
||||||
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||||
|
checkpoint["hyper_parameters"]["component_initializer"] = None
|
||||||
checkpoint["hyper_parameters"] = {
|
checkpoint["hyper_parameters"] = {
|
||||||
'hparams': checkpoint["hyper_parameters"]
|
'hparams': checkpoint["hyper_parameters"]
|
||||||
}
|
}
|
||||||
|
@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
|
|||||||
from prototorch.core.initializers import (
|
from prototorch.core.initializers import (
|
||||||
AbstractComponentsInitializer,
|
AbstractComponentsInitializer,
|
||||||
LabelsInitializer,
|
LabelsInitializer,
|
||||||
|
ZerosCompInitializer,
|
||||||
)
|
)
|
||||||
from prototorch.y import BaseYArchitecture
|
from prototorch.y import BaseYArchitecture
|
||||||
|
|
||||||
@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
|
|||||||
# Steps
|
# Steps
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def init_components(self, hparams: HyperParameters):
|
def init_components(self, hparams: HyperParameters):
|
||||||
self.components_layer = LabeledComponents(
|
if hparams.component_initializer is not None:
|
||||||
distribution=hparams.distribution,
|
self.components_layer = LabeledComponents(
|
||||||
components_initializer=hparams.component_initializer,
|
distribution=hparams.distribution,
|
||||||
labels_initializer=LabelsInitializer(),
|
components_initializer=hparams.component_initializer,
|
||||||
)
|
labels_initializer=LabelsInitializer(),
|
||||||
|
)
|
||||||
|
proto_shape = self.components_layer.components.shape[1:]
|
||||||
|
self.hparams["initialized_proto_shape"] = proto_shape
|
||||||
|
else:
|
||||||
|
# when restoring a checkpointed model
|
||||||
|
self.components_layer = LabeledComponents(
|
||||||
|
distribution=hparams.distribution,
|
||||||
|
components_initializer=ZerosCompInitializer(
|
||||||
|
self.hparams["initialized_proto_shape"]),
|
||||||
|
)
|
||||||
|
|
||||||
# Properties
|
# Properties
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
Loading…
Reference in New Issue
Block a user