feat: metrics can be assigned to the different phases
This commit is contained in:
parent
46ec7b07d7
commit
736565b768
@ -23,6 +23,13 @@ ProtoTorch Models Plugins
|
|||||||
|
|
||||||
custom
|
custom
|
||||||
|
|
||||||
|
.. toctree::
|
||||||
|
:hidden:
|
||||||
|
:maxdepth: 3
|
||||||
|
:caption: Proto Y Architecture
|
||||||
|
|
||||||
|
y-architecture
|
||||||
|
|
||||||
About
|
About
|
||||||
-----------------------------------------
|
-----------------------------------------
|
||||||
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin
|
`Prototorch Models <https://github.com/si-cim/prototorch_models>`_ is a Plugin
|
||||||
@ -33,8 +40,10 @@ prototype-based Machine Learning algorithms using `PyTorch-Lightning
|
|||||||
Library
|
Library
|
||||||
-----------------------------------------
|
-----------------------------------------
|
||||||
Prototorch Models delivers many application ready models.
|
Prototorch Models delivers many application ready models.
|
||||||
These models have been published in the past and have been adapted to the Prototorch library.
|
These models have been published in the past and have been adapted to the
|
||||||
|
Prototorch library.
|
||||||
|
|
||||||
Customizable
|
Customizable
|
||||||
-----------------------------------------
|
-----------------------------------------
|
||||||
Prototorch Models also contains the building blocks to build own models with PyTorch-Lightning and Prototorch.
|
Prototorch Models also contains the building blocks to build own models with
|
||||||
|
PyTorch-Lightning and Prototorch.
|
||||||
|
@ -71,7 +71,7 @@ Probabilistic Models
|
|||||||
Probabilistic variants assume, that the prototypes generate a probability distribution over the classes.
|
Probabilistic variants assume, that the prototypes generate a probability distribution over the classes.
|
||||||
For a test sample they return a distribution instead of a class assignment.
|
For a test sample they return a distribution instead of a class assignment.
|
||||||
|
|
||||||
The following two algorihms were presented by :cite:t:`seo2003` .
|
The following two algorithms were presented by :cite:t:`seo2003` .
|
||||||
Every prototypes is a center of a gaussian distribution of its class, generating a mixture model.
|
Every prototypes is a center of a gaussian distribution of its class, generating a mixture model.
|
||||||
|
|
||||||
.. autoclass:: prototorch.models.probabilistic.SLVQ
|
.. autoclass:: prototorch.models.probabilistic.SLVQ
|
||||||
@ -80,7 +80,7 @@ Every prototypes is a center of a gaussian distribution of its class, generating
|
|||||||
.. autoclass:: prototorch.models.probabilistic.RSLVQ
|
.. autoclass:: prototorch.models.probabilistic.RSLVQ
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation.
|
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incorporate the winning rank into the prior probability calculation.
|
||||||
And second use divergence as loss function.
|
And second use divergence as loss function.
|
||||||
|
|
||||||
.. autoclass:: prototorch.models.probabilistic.PLVQ
|
.. autoclass:: prototorch.models.probabilistic.PLVQ
|
||||||
@ -106,7 +106,7 @@ Visualization
|
|||||||
Visualization is very specific to its application.
|
Visualization is very specific to its application.
|
||||||
PrototorchModels delivers visualization for two dimensional data and image data.
|
PrototorchModels delivers visualization for two dimensional data and image data.
|
||||||
|
|
||||||
The visulizations can be shown in a seperate window and inside a tensorboard.
|
The visualizations can be shown in a separate window and inside a tensorboard.
|
||||||
|
|
||||||
.. automodule:: prototorch.models.vis
|
.. automodule:: prototorch.models.vis
|
||||||
:members:
|
:members:
|
||||||
|
71
docs/source/y-architecture.rst
Normal file
71
docs/source/y-architecture.rst
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
.. Documentation of the updated Architecture.
|
||||||
|
|
||||||
|
Proto Y Architecture
|
||||||
|
========================================
|
||||||
|
|
||||||
|
Overview
|
||||||
|
****************************************
|
||||||
|
|
||||||
|
The Proto Y Architecture is a framework for abstract prototype learning methods.
|
||||||
|
|
||||||
|
It divides the problem into multiple steps:
|
||||||
|
|
||||||
|
* **Components** : Recalling the position and metadata of the components/prototypes.
|
||||||
|
* **Backbone** : Apply a mapping function to data and prototypes.
|
||||||
|
* **Comparison** : Calculate a dissimilarity based on the latent positions.
|
||||||
|
* **Competition** : Calculate competition values based on the comparison and the metadata.
|
||||||
|
* **Loss** : Calculate the loss based on the competition values
|
||||||
|
* **Inference** : Predict the output based on the competition values.
|
||||||
|
|
||||||
|
Depending on the phase (Training or Testing) Loss or Inference is used.
|
||||||
|
|
||||||
|
Inheritance Structure
|
||||||
|
****************************************
|
||||||
|
|
||||||
|
The Proto Y Architecture has a single base class that defines all steps and hooks
|
||||||
|
of the architecture.
|
||||||
|
|
||||||
|
.. autoclass:: prototorch.y.architectures.base.BaseYArchitecture
|
||||||
|
|
||||||
|
**Steps**
|
||||||
|
|
||||||
|
Components
|
||||||
|
|
||||||
|
.. automethod:: init_components
|
||||||
|
.. automethod:: components
|
||||||
|
|
||||||
|
Backbone
|
||||||
|
|
||||||
|
.. automethod:: init_backbone
|
||||||
|
.. automethod:: backbone
|
||||||
|
|
||||||
|
Comparison
|
||||||
|
|
||||||
|
.. automethod:: init_comparison
|
||||||
|
.. automethod:: comparison
|
||||||
|
|
||||||
|
Competition
|
||||||
|
|
||||||
|
.. automethod:: init_competition
|
||||||
|
.. automethod:: competition
|
||||||
|
|
||||||
|
Loss
|
||||||
|
|
||||||
|
.. automethod:: init_loss
|
||||||
|
.. automethod:: loss
|
||||||
|
|
||||||
|
Inference
|
||||||
|
|
||||||
|
.. automethod:: init_inference
|
||||||
|
.. automethod:: inference
|
||||||
|
|
||||||
|
**Hooks**
|
||||||
|
|
||||||
|
Torchmetric
|
||||||
|
|
||||||
|
.. automethod:: register_torchmetric
|
||||||
|
|
||||||
|
Hyperparameters
|
||||||
|
****************************************
|
||||||
|
Every model implemented with the Proto Y Architecture has a set of hyperparameters,
|
||||||
|
which is stored in the ``HyperParameters`` attribute of the architecture.
|
@ -1,7 +1,10 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.core import SMCI
|
from prototorch.core import SMCI
|
||||||
|
from prototorch.y.architectures.base import Steps
|
||||||
from prototorch.y.callbacks import (
|
from prototorch.y.callbacks import (
|
||||||
LogTorchmetricCallback,
|
LogTorchmetricCallback,
|
||||||
PlotLambdaMatrixToTensorboard,
|
PlotLambdaMatrixToTensorboard,
|
||||||
@ -9,7 +12,9 @@ from prototorch.y.callbacks import (
|
|||||||
)
|
)
|
||||||
from prototorch.y.library.gmlvq import GMLVQ
|
from prototorch.y.library.gmlvq import GMLVQ
|
||||||
from pytorch_lightning.callbacks import EarlyStopping
|
from pytorch_lightning.callbacks import EarlyStopping
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader, random_split
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|
||||||
# ##############################################################################
|
# ##############################################################################
|
||||||
|
|
||||||
@ -20,22 +25,42 @@ def main():
|
|||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Iris()
|
full_dataset = pt.datasets.Iris()
|
||||||
|
full_count = len(full_dataset)
|
||||||
|
|
||||||
|
train_count = int(full_count * 0.5)
|
||||||
|
val_count = int(full_count * 0.4)
|
||||||
|
test_count = int(full_count * 0.1)
|
||||||
|
|
||||||
|
train_dataset, val_dataset, test_dataset = random_split(
|
||||||
|
full_dataset, (train_count, val_count, test_count))
|
||||||
|
|
||||||
# Dataloader
|
# Dataloader
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_ds,
|
train_dataset,
|
||||||
batch_size=32,
|
batch_size=1,
|
||||||
num_workers=0,
|
num_workers=4,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
val_loader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=4,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
test_loader = DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=0,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
# HYPERPARAMETERS
|
# HYPERPARAMETERS
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
# Select Initializer
|
# Select Initializer
|
||||||
components_initializer = SMCI(train_ds)
|
components_initializer = SMCI(full_dataset)
|
||||||
|
|
||||||
# Define Hyperparameters
|
# Define Hyperparameters
|
||||||
hyperparameters = GMLVQ.HyperParameters(
|
hyperparameters = GMLVQ.HyperParameters(
|
||||||
@ -51,17 +76,23 @@ def main():
|
|||||||
# Create Model
|
# Create Model
|
||||||
model = GMLVQ(hyperparameters)
|
model = GMLVQ(hyperparameters)
|
||||||
|
|
||||||
print(model.hparams)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
# TRAINING
|
# TRAINING
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
|
|
||||||
# Controlling Callbacks
|
# Controlling Callbacks
|
||||||
stopping_criterion = LogTorchmetricCallback(
|
recall = LogTorchmetricCallback(
|
||||||
'recall',
|
'training_recall',
|
||||||
torchmetrics.Recall,
|
torchmetrics.Recall,
|
||||||
num_classes=3,
|
num_classes=3,
|
||||||
|
step=Steps.TRAINING,
|
||||||
|
)
|
||||||
|
|
||||||
|
stopping_criterion = LogTorchmetricCallback(
|
||||||
|
'validation_recall',
|
||||||
|
torchmetrics.Recall,
|
||||||
|
num_classes=3,
|
||||||
|
step=Steps.VALIDATION,
|
||||||
)
|
)
|
||||||
|
|
||||||
es = EarlyStopping(
|
es = EarlyStopping(
|
||||||
@ -71,18 +102,23 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Visualization Callback
|
# Visualization Callback
|
||||||
vis = VisGMLVQ2D(data=train_ds)
|
vis = VisGMLVQ2D(data=full_dataset)
|
||||||
|
|
||||||
# Define trainer
|
# Define trainer
|
||||||
trainer = pl.Trainer(callbacks=[
|
trainer = pl.Trainer(
|
||||||
vis,
|
callbacks=[
|
||||||
stopping_criterion,
|
vis,
|
||||||
es,
|
recall,
|
||||||
PlotLambdaMatrixToTensorboard(),
|
stopping_criterion,
|
||||||
], )
|
es,
|
||||||
|
PlotLambdaMatrixToTensorboard(),
|
||||||
|
],
|
||||||
|
max_epochs=100,
|
||||||
|
)
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader, val_loader)
|
||||||
|
trainer.test(model, test_loader)
|
||||||
|
|
||||||
# Manual save
|
# Manual save
|
||||||
trainer.save_checkpoint("./y_arch.ckpt")
|
trainer.save_checkpoint("./y_arch.ckpt")
|
||||||
@ -93,8 +129,6 @@ def main():
|
|||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(new_model.hparams)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
@ -13,15 +13,34 @@ import torch
|
|||||||
from torchmetrics import Metric
|
from torchmetrics import Metric
|
||||||
|
|
||||||
|
|
||||||
|
class Steps(enumerate):
|
||||||
|
TRAINING = "training"
|
||||||
|
VALIDATION = "validation"
|
||||||
|
TEST = "test"
|
||||||
|
PREDICT = "predict"
|
||||||
|
|
||||||
|
|
||||||
class BaseYArchitecture(pl.LightningModule):
|
class BaseYArchitecture(pl.LightningModule):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HyperParameters:
|
class HyperParameters:
|
||||||
|
"""
|
||||||
|
Add all hyperparameters in the inherited class.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
registered_metrics: dict[type[Metric], Metric] = {}
|
registered_metrics: dict[str, dict[type[Metric], Metric]] = {
|
||||||
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
|
Steps.TRAINING: {},
|
||||||
|
Steps.VALIDATION: {},
|
||||||
|
Steps.TEST: {},
|
||||||
|
}
|
||||||
|
registered_metric_callbacks: dict[str, dict[type[Metric],
|
||||||
|
set[Callable]]] = {
|
||||||
|
Steps.TRAINING: {},
|
||||||
|
Steps.VALIDATION: {},
|
||||||
|
Steps.TEST: {},
|
||||||
|
}
|
||||||
|
|
||||||
# Type Hints for Necessary Fields
|
# Type Hints for Necessary Fields
|
||||||
components_layer: torch.nn.Module
|
components_layer: torch.nn.Module
|
||||||
@ -41,7 +60,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
|
|
||||||
# Common Steps
|
# Common Steps
|
||||||
self.init_components(hparams)
|
self.init_components(hparams)
|
||||||
self.init_latent(hparams)
|
self.init_backbone(hparams)
|
||||||
self.init_comparison(hparams)
|
self.init_comparison(hparams)
|
||||||
self.init_competition(hparams)
|
self.init_competition(hparams)
|
||||||
|
|
||||||
@ -53,7 +72,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
|
|
||||||
# external API
|
# external API
|
||||||
def get_competition(self, batch, components):
|
def get_competition(self, batch, components):
|
||||||
latent_batch, latent_components = self.latent(batch, components)
|
latent_batch, latent_components = self.backbone(batch, components)
|
||||||
# TODO: => Latent Hook
|
# TODO: => Latent Hook
|
||||||
comparison_tensor = self.comparison(latent_batch, latent_components)
|
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||||
# TODO: => Comparison Hook
|
# TODO: => Comparison Hook
|
||||||
@ -92,27 +111,43 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
return self.loss(comparison_tensor, batch, components)
|
return self.loss(comparison_tensor, batch, components)
|
||||||
|
|
||||||
# Empty Initialization
|
# Empty Initialization
|
||||||
# TODO: Docs
|
|
||||||
def init_components(self, hparams: HyperParameters) -> None:
|
def init_components(self, hparams: HyperParameters) -> None:
|
||||||
|
"""
|
||||||
|
All initialization necessary for the components step.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def init_latent(self, hparams: HyperParameters) -> None:
|
def init_backbone(self, hparams: HyperParameters) -> None:
|
||||||
|
"""
|
||||||
|
All initialization necessary for the backbone step.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||||
|
"""
|
||||||
|
All initialization necessary for the comparison step.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def init_competition(self, hparams: HyperParameters) -> None:
|
def init_competition(self, hparams: HyperParameters) -> None:
|
||||||
|
"""
|
||||||
|
All initialization necessary for the competition step.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def init_loss(self, hparams: HyperParameters) -> None:
|
def init_loss(self, hparams: HyperParameters) -> None:
|
||||||
|
"""
|
||||||
|
All initialization necessary for the loss step.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def init_inference(self, hparams: HyperParameters) -> None:
|
def init_inference(self, hparams: HyperParameters) -> None:
|
||||||
|
"""
|
||||||
|
All initialization necessary for the inference step.
|
||||||
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
# Empty Steps
|
# Empty Steps
|
||||||
# TODO: Type hints
|
|
||||||
def components(self):
|
def components(self):
|
||||||
"""
|
"""
|
||||||
This step has no input.
|
This step has no input.
|
||||||
@ -122,9 +157,9 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"The components step has no reasonable default.")
|
"The components step has no reasonable default.")
|
||||||
|
|
||||||
def latent(self, batch, components):
|
def backbone(self, batch, components):
|
||||||
"""
|
"""
|
||||||
The latent step receives the data batch and the components.
|
The backbone step receives the data batch and the components.
|
||||||
It can transform both by an arbitrary function.
|
It can transform both by an arbitrary function.
|
||||||
|
|
||||||
It returns the transformed batch and components, each of the same length as the original input.
|
It returns the transformed batch and components, each of the same length as the original input.
|
||||||
@ -173,52 +208,72 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
self,
|
self,
|
||||||
name: Callable,
|
name: Callable,
|
||||||
metric: type[Metric],
|
metric: type[Metric],
|
||||||
|
step: str = Steps.TRAINING,
|
||||||
**metric_kwargs,
|
**metric_kwargs,
|
||||||
):
|
):
|
||||||
if metric not in self.registered_metrics:
|
if step == Steps.PREDICT:
|
||||||
self.registered_metrics[metric] = metric(**metric_kwargs)
|
raise ValueError("Prediction metrics are not supported.")
|
||||||
self.registered_metric_callbacks[metric] = {name}
|
|
||||||
else:
|
|
||||||
self.registered_metric_callbacks[metric].add(name)
|
|
||||||
|
|
||||||
def update_metrics_step(self, batch):
|
if metric not in self.registered_metrics:
|
||||||
|
self.registered_metrics[step][metric] = metric(**metric_kwargs)
|
||||||
|
self.registered_metric_callbacks[step][metric] = {name}
|
||||||
|
else:
|
||||||
|
self.registered_metric_callbacks[step][metric].add(name)
|
||||||
|
|
||||||
|
def update_metrics_step(self, batch, step):
|
||||||
# Prediction Metrics
|
# Prediction Metrics
|
||||||
preds = self(batch)
|
preds = self(batch)
|
||||||
|
|
||||||
x, y = batch
|
x, y = batch
|
||||||
for metric in self.registered_metrics:
|
for metric in self.registered_metrics[step]:
|
||||||
instance = self.registered_metrics[metric].to(self.device)
|
instance = self.registered_metrics[step][metric].to(self.device)
|
||||||
instance(y, preds)
|
instance(y, preds)
|
||||||
|
|
||||||
def update_metrics_epoch(self):
|
def update_metrics_epoch(self, step):
|
||||||
for metric in self.registered_metrics:
|
for metric in self.registered_metrics[step]:
|
||||||
instance = self.registered_metrics[metric].to(self.device)
|
instance = self.registered_metrics[step][metric].to(self.device)
|
||||||
value = instance.compute()
|
value = instance.compute()
|
||||||
|
|
||||||
for callback in self.registered_metric_callbacks[metric]:
|
for callback in self.registered_metric_callbacks[step][metric]:
|
||||||
callback(value, self)
|
callback(value, self)
|
||||||
|
|
||||||
instance.reset()
|
instance.reset()
|
||||||
|
|
||||||
# Lightning Hooks
|
# Lightning steps
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
# Steps
|
# >>>> Training
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
self.update_metrics_step([torch.clone(el) for el in batch])
|
self.update_metrics_step(batch, Steps.TRAINING)
|
||||||
|
|
||||||
return self.loss_forward(batch)
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
def validation_step(self, batch, batch_idx):
|
|
||||||
return self.loss_forward(batch)
|
|
||||||
|
|
||||||
def test_step(self, batch, batch_idx):
|
|
||||||
return self.loss_forward(batch)
|
|
||||||
|
|
||||||
# Other Hooks
|
|
||||||
def training_epoch_end(self, outs) -> None:
|
def training_epoch_end(self, outs) -> None:
|
||||||
self.update_metrics_epoch()
|
self.update_metrics_epoch(Steps.TRAINING)
|
||||||
|
|
||||||
|
# >>>> Validation
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
self.update_metrics_step(batch, Steps.VALIDATION)
|
||||||
|
|
||||||
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
|
def validation_epoch_end(self, outs) -> None:
|
||||||
|
self.update_metrics_epoch(Steps.VALIDATION)
|
||||||
|
|
||||||
|
# >>>> Test
|
||||||
|
def test_step(self, batch, batch_idx):
|
||||||
|
self.update_metrics_step(batch, Steps.TEST)
|
||||||
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
|
def test_epoch_end(self, outs) -> None:
|
||||||
|
self.update_metrics_epoch(Steps.TEST)
|
||||||
|
|
||||||
|
# >>>> Prediction
|
||||||
|
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
||||||
|
return self.predict(batch)
|
||||||
|
|
||||||
|
# Check points
|
||||||
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||||
|
# Compatible with Lightning
|
||||||
checkpoint["hyper_parameters"] = {
|
checkpoint["hyper_parameters"] = {
|
||||||
'hparams': checkpoint["hyper_parameters"]
|
'hparams': checkpoint["hyper_parameters"]
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
@ -8,7 +9,7 @@ import torchmetrics
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.models.vis import Vis2DAbstract
|
from prototorch.models.vis import Vis2DAbstract
|
||||||
from prototorch.utils.utils import mesh2d
|
from prototorch.utils.utils import mesh2d
|
||||||
from prototorch.y.architectures.base import BaseYArchitecture
|
from prototorch.y.architectures.base import BaseYArchitecture, Steps
|
||||||
from prototorch.y.library.gmlvq import GMLVQ
|
from prototorch.y.library.gmlvq import GMLVQ
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
|
|
||||||
@ -34,13 +35,13 @@ class LogTorchmetricCallback(pl.Callback):
|
|||||||
self,
|
self,
|
||||||
name,
|
name,
|
||||||
metric: Type[torchmetrics.Metric],
|
metric: Type[torchmetrics.Metric],
|
||||||
on="prediction",
|
step: str = Steps.TRAINING,
|
||||||
**metric_kwargs,
|
**metric_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
self.metric_kwargs = metric_kwargs
|
self.metric_kwargs = metric_kwargs
|
||||||
self.on = on
|
self.step = step
|
||||||
|
|
||||||
def setup(
|
def setup(
|
||||||
self,
|
self,
|
||||||
@ -48,14 +49,12 @@ class LogTorchmetricCallback(pl.Callback):
|
|||||||
pl_module: BaseYArchitecture,
|
pl_module: BaseYArchitecture,
|
||||||
stage: Optional[str] = None,
|
stage: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.on == "prediction":
|
pl_module.register_torchmetric(
|
||||||
pl_module.register_torchmetric(
|
self,
|
||||||
self,
|
self.metric,
|
||||||
self.metric,
|
step=self.step,
|
||||||
**self.metric_kwargs,
|
**self.metric_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise ValueError(f"{self.on} is no valid metric hook")
|
|
||||||
|
|
||||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||||
pl_module.log(self.name, value)
|
pl_module.log(self.name, value)
|
||||||
|
Loading…
Reference in New Issue
Block a user