fix: dont save prototype initializer in yarch checkpoint
This commit is contained in:
parent
be7d7f43bd
commit
60d2a1d2c9
@ -6,13 +6,7 @@ Network architecture for Component based Learning.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Set,
|
||||
Type,
|
||||
)
|
||||
from typing import Any, Callable
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
@ -34,12 +28,15 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
|
||||
def __init__(self, hparams) -> None:
|
||||
if type(hparams) is dict:
|
||||
self.save_hyperparameters(hparams)
|
||||
# TODO: => Move into Component Child
|
||||
del hparams["initialized_proto_shape"]
|
||||
hparams = self.HyperParameters(**hparams)
|
||||
else:
|
||||
self.save_hyperparameters(hparams.__dict__)
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.save_hyperparameters(hparams.__dict__)
|
||||
|
||||
# Common Steps
|
||||
self.init_components(hparams)
|
||||
self.init_latent(hparams)
|
||||
@ -220,6 +217,7 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
self.update_metrics_epoch()
|
||||
|
||||
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||
checkpoint["hyper_parameters"]["component_initializer"] = None
|
||||
checkpoint["hyper_parameters"] = {
|
||||
'hparams': checkpoint["hyper_parameters"]
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
|
||||
from prototorch.core.initializers import (
|
||||
AbstractComponentsInitializer,
|
||||
LabelsInitializer,
|
||||
ZerosCompInitializer,
|
||||
)
|
||||
from prototorch.y import BaseYArchitecture
|
||||
|
||||
@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def init_components(self, hparams: HyperParameters):
|
||||
if hparams.component_initializer is not None:
|
||||
self.components_layer = LabeledComponents(
|
||||
distribution=hparams.distribution,
|
||||
components_initializer=hparams.component_initializer,
|
||||
labels_initializer=LabelsInitializer(),
|
||||
)
|
||||
proto_shape = self.components_layer.components.shape[1:]
|
||||
self.hparams["initialized_proto_shape"] = proto_shape
|
||||
else:
|
||||
# when restoring a checkpointed model
|
||||
self.components_layer = LabeledComponents(
|
||||
distribution=hparams.distribution,
|
||||
components_initializer=ZerosCompInitializer(
|
||||
self.hparams["initialized_proto_shape"]),
|
||||
)
|
||||
|
||||
# Properties
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
Loading…
Reference in New Issue
Block a user