From c5dd194bb362eec75c12a6fe23838b85bd02147d Mon Sep 17 00:00:00 2001 From: julius Date: Mon, 18 Nov 2024 13:39:35 +0100 Subject: [PATCH] starting point to develop an example script --- hello.py | 6 ----- main.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 6 deletions(-) delete mode 100644 hello.py create mode 100644 main.py diff --git a/hello.py b/hello.py deleted file mode 100644 index c53e311..0000000 --- a/hello.py +++ /dev/null @@ -1,6 +0,0 @@ -def main(): - print("Hello from experiment!") - - -if __name__ == "__main__": - main() diff --git a/main.py b/main.py new file mode 100644 index 0000000..8f98abb --- /dev/null +++ b/main.py @@ -0,0 +1,79 @@ +"""Write a short description of your experiment here.""" + +import logging +import lightning as L +import torch +from torch.utils.data import DataLoader, random_split + +from protothor.callbacks.normalization import OmegaTraceNormalization +from protothor.data.preprocessing import Standardizer +from lightning.pytorch.callbacks import EarlyStopping +from protothor.data.toy import Iris +from protothor.functional.initialization import zero_initialization +from protothor.lightning.metric_module import MetricModule +from protothor.models.gmlvq import generate_model as generate_gmlvq +from protothor.nn.collectable_loss import EvaluateExposedLosses +from protothor.nn.container import find_instances + +logging.basicConfig(level=logging.INFO) + + +def main(): + # 1 - Get Dataset + data = Iris([0, 2]) + train_ds, val_ds = random_split(data, [0.7, 0.3]) + + standardize = Standardizer(train_ds) + + # 2 - Create Dataloaders + train_loader = DataLoader( + standardize(train_ds), + shuffle=True, + batch_size=len(train_ds), + ) + val_loader = DataLoader( + standardize(val_ds), + shuffle=False, + batch_size=len(val_ds), + ) + + # 3 - Initialize Prototypes + labels, positions = zero_initialization(data) + + # 4 - Generate Torch Model + model = generate_gmlvq(positions, labels) + omega_layer = find_instances(model, torch.nn.Linear)[0] + + # 5 - Initialize Lightning Module + module = MetricModule(model, EvaluateExposedLosses(model)) + + # 6 - Define Callbacks + + callbacks = [ + OmegaTraceNormalization(omega_layer), + ] + callbacks.append( + EarlyStopping( + monitor="accuracy/validation", + min_delta=0.0002, + patience=50, + mode="max", + verbose=False, + check_on_train_epoch_end=True, + ) + ) + + # 7 - Define Trainer + trainer = L.Trainer(detect_anomaly=True, max_epochs=1000, callbacks=callbacks) + + # 8 - Train Model + trainer.fit(module, train_loader, val_loader) + + # 9 - Analyse results + omega = omega_layer.weight.detach() + classification_correlation_matrix = omega.T @ omega + logging.info(classification_correlation_matrix) + + +if __name__ == "__main__": + main()