fix: dont save prototype initializer in yarch checkpoint

This commit is contained in:
Alexander Engelsberger 2022-06-12 11:12:55 +02:00
parent be7d7f43bd
commit 60d2a1d2c9
2 changed files with 23 additions and 14 deletions

View File

@ -6,13 +6,7 @@ Network architecture for Component based Learning.
from __future__ import annotations
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Set,
Type,
)
from typing import Any, Callable
import pytorch_lightning as pl
import torch
@ -34,12 +28,15 @@ class BaseYArchitecture(pl.LightningModule):
def __init__(self, hparams) -> None:
if type(hparams) is dict:
self.save_hyperparameters(hparams)
# TODO: => Move into Component Child
del hparams["initialized_proto_shape"]
hparams = self.HyperParameters(**hparams)
else:
self.save_hyperparameters(hparams.__dict__)
super().__init__()
self.save_hyperparameters(hparams.__dict__)
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
@ -220,6 +217,7 @@ class BaseYArchitecture(pl.LightningModule):
self.update_metrics_epoch()
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
checkpoint["hyper_parameters"]["component_initializer"] = None
checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"]
}

View File

@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.y import BaseYArchitecture
@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
# Steps
# ----------------------------------------------------------------------------------------------------
def init_components(self, hparams: HyperParameters):
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
if hparams.component_initializer is not None:
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
proto_shape = self.components_layer.components.shape[1:]
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=ZerosCompInitializer(
self.hparams["initialized_proto_shape"]),
)
# Properties
# ----------------------------------------------------------------------------------------------------