Compare commits

..

2 Commits

Author SHA1 Message Date
c5dd194bb3
starting point to develop an example script 2024-11-18 13:39:35 +01:00
f103e1b00a
add protoThor 2024-11-14 15:49:35 +01:00
4 changed files with 1309 additions and 7 deletions

View File

@ -1,6 +0,0 @@
def main():
print("Hello from experiment!")
if __name__ == "__main__":
main()

79
main.py Normal file
View File

@ -0,0 +1,79 @@
"""Write a short description of your experiment here."""
import logging
import lightning as L
import torch
from torch.utils.data import DataLoader, random_split
from protothor.callbacks.normalization import OmegaTraceNormalization
from protothor.data.preprocessing import Standardizer
from lightning.pytorch.callbacks import EarlyStopping
from protothor.data.toy import Iris
from protothor.functional.initialization import zero_initialization
from protothor.lightning.metric_module import MetricModule
from protothor.models.gmlvq import generate_model as generate_gmlvq
from protothor.nn.collectable_loss import EvaluateExposedLosses
from protothor.nn.container import find_instances
logging.basicConfig(level=logging.INFO)
def main():
# 1 - Get Dataset
data = Iris([0, 2])
train_ds, val_ds = random_split(data, [0.7, 0.3])
standardize = Standardizer(train_ds)
# 2 - Create Dataloaders
train_loader = DataLoader(
standardize(train_ds),
shuffle=True,
batch_size=len(train_ds),
)
val_loader = DataLoader(
standardize(val_ds),
shuffle=False,
batch_size=len(val_ds),
)
# 3 - Initialize Prototypes
labels, positions = zero_initialization(data)
# 4 - Generate Torch Model
model = generate_gmlvq(positions, labels)
omega_layer = find_instances(model, torch.nn.Linear)[0]
# 5 - Initialize Lightning Module
module = MetricModule(model, EvaluateExposedLosses(model))
# 6 - Define Callbacks
callbacks = [
OmegaTraceNormalization(omega_layer),
]
callbacks.append(
EarlyStopping(
monitor="accuracy/validation",
min_delta=0.0002,
patience=50,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
)
# 7 - Define Trainer
trainer = L.Trainer(detect_anomaly=True, max_epochs=1000, callbacks=callbacks)
# 8 - Train Model
trainer.fit(module, train_loader, val_loader)
# 9 - Analyse results
omega = omega_layer.weight.detach()
classification_correlation_matrix = omega.T @ omega
logging.info(classification_correlation_matrix)
if __name__ == "__main__":
main()

View File

@ -4,4 +4,26 @@ version = "0.1.0"
description = "Add your description here" description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12" requires-python = ">=3.12"
dependencies = [] dependencies = [
"protothor>=0.1.1",
"protothor-datarecords>=0.8.1",
]
[tool.uv.sources]
protothor = [{index = "protothor-gitlab"}]
protothor-datarecords = [{index = "protothor-datarecords-gitlab"}]
[[tool.uv.index]]
name = "protothor-datarecords-gitlab"
url = "https://git.hs-mittweida.de/api/v4/projects/2428/packages/pypi/simple"
default = false
[[tool.uv.index]]
name = "protothor-gitlab"
url = "https://git.hs-mittweida.de/api/v4/projects/2346/packages/pypi/simple"
default = false
[[tool.uv.index]]
name = "torch-cpu"
url = "https://download.pytorch.org/whl/cpu"
default = false

1207
uv.lock

File diff suppressed because it is too large Load Diff