fix: dont save prototype initializer in yarch checkpoint

Alexander Engelsberger 2022-06-12 11:12:55 +02:00
parent be7d7f43bd
commit 60d2a1d2c9
2 changed files with 23 additions and 14 deletions

View File

@@ -6,13 +6,7 @@ Network architecture for Component based Learning.
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Set,
-    Type,
-)
+from typing import Any, Callable

 import pytorch_lightning as pl
 import torch
@@ -34,12 +28,15 @@ class BaseYArchitecture(pl.LightningModule):
     def __init__(self, hparams) -> None:
         if type(hparams) is dict:
+            self.save_hyperparameters(hparams)
+            # TODO: => Move into Component Child
+            del hparams["initialized_proto_shape"]
             hparams = self.HyperParameters(**hparams)
+        else:
+            self.save_hyperparameters(hparams.__dict__)
         super().__init__()
-        self.save_hyperparameters(hparams.__dict__)

         # Common Steps
         self.init_components(hparams)
         self.init_latent(hparams)
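Note on the reshuffled __init__: a fresh run passes a HyperParameters dataclass (saved via __dict__), while a checkpoint restore passes the previously saved dict back in, which now also carries the bookkeeping key "initialized_proto_shape"; the dataclass has no such field, so the key is dropped before **-unpacking. A minimal sketch of that branching, with a stand-in dataclass and helper name that are illustrative assumptions, not part of this diff:

from dataclasses import dataclass
from typing import Any


@dataclass
class HyperParameters:  # stand-in for BaseYArchitecture.HyperParameters
    distribution: dict
    component_initializer: Any = None


def resolve_hparams(hparams):
    """Hypothetical mirror of the dict/dataclass handling above."""
    if type(hparams) is dict:
        logged = dict(hparams)                         # what save_hyperparameters(hparams) records
        hparams.pop("initialized_proto_shape", None)   # not a dataclass field -> drop before **-unpacking
        return HyperParameters(**hparams), logged
    return hparams, dict(hparams.__dict__)             # fresh run: dataclass instance


# Fresh training run: dataclass in, extra key absent.
hp, _ = resolve_hparams(HyperParameters(distribution={"num_classes": 3, "per_class": 2}))

# Checkpoint restore: the saved dict comes back, including the shape key.
hp, _ = resolve_hparams({
    "distribution": {"num_classes": 3, "per_class": 2},
    "component_initializer": None,
    "initialized_proto_shape": (4, ),
})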
@@ -220,6 +217,7 @@ class BaseYArchitecture(pl.LightningModule):
         self.update_metrics_epoch()

     def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        checkpoint["hyper_parameters"]["component_initializer"] = None
         checkpoint["hyper_parameters"] = {
             'hparams': checkpoint["hyper_parameters"]
         }
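Nulling out component_initializer before the hyperparameters are wrapped keeps data-dependent initializers (which may hold a reference to the whole training set) out of the pickled checkpoint; the trained prototypes are still restored from the state dict. A hedged round-trip sketch, assuming a trained trainer and a concrete subclass GLVQ from an earlier run, neither of which is part of this diff:

import torch

# trainer, GLVQ: assumed to exist from an earlier training run (not shown here)
ckpt_path = "yarch_model.ckpt"                   # illustrative path
trainer.save_checkpoint(ckpt_path)               # triggers on_save_checkpoint()

ckpt = torch.load(ckpt_path, map_location="cpu")
# The initializer was replaced by None and everything was wrapped under 'hparams'.
assert ckpt["hyper_parameters"]["hparams"]["component_initializer"] is None

restored = GLVQ.load_from_checkpoint(ckpt_path)  # __init__ then takes the dict branch above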

View File

@@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
 from prototorch.core.initializers import (
     AbstractComponentsInitializer,
     LabelsInitializer,
+    ZerosCompInitializer,
 )

 from prototorch.y import BaseYArchitecture
@@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
     # Steps
     # ----------------------------------------------------------------------------------------------------
     def init_components(self, hparams: HyperParameters):
-        self.components_layer = LabeledComponents(
-            distribution=hparams.distribution,
-            components_initializer=hparams.component_initializer,
-            labels_initializer=LabelsInitializer(),
-        )
+        if hparams.component_initializer is not None:
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=hparams.component_initializer,
+                labels_initializer=LabelsInitializer(),
+            )
+            proto_shape = self.components_layer.components.shape[1:]
+            self.hparams["initialized_proto_shape"] = proto_shape
+        else:
+            # when restoring a checkpointed model
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=ZerosCompInitializer(
+                    self.hparams["initialized_proto_shape"]),
+            )

     # Properties
     # ----------------------------------------------------------------------------------------------------
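On a fresh run the user-supplied initializer builds the prototypes and their shape is recorded under "initialized_proto_shape"; on restore the initializer is None (it was nulled at save time), so the layer is rebuilt with zero-valued placeholders of the recorded shape and the trained values arrive afterwards via the checkpoint's state dict. A small standalone sketch of that restore path, using only the calls visible in the diff; the distribution and shape values below are made up for illustration:

from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import ZerosCompInitializer

stored_shape = (4, )  # would come from hparams["initialized_proto_shape"]

restored_layer = LabeledComponents(
    distribution={"num_classes": 3, "per_class": 2},
    components_initializer=ZerosCompInitializer(stored_shape),
)
# The zero prototypes are placeholders only; load_state_dict() overwrites them
# with the trained tensors afterwards.
print(restored_layer.components.shape)  # expected torch.Size([6, 4]) for 3 classes x 2 prototypes of dim 4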