"""Write a short description of your experiment here.""" import logging import lightning as L import torch from torch.utils.data import DataLoader, random_split from protothor.callbacks.normalization import OmegaTraceNormalization from protothor.data.preprocessing import Standardizer from lightning.pytorch.callbacks import EarlyStopping from protothor.data.toy import Iris from protothor.functional.initialization import zero_initialization from protothor.lightning.metric_module import MetricModule from protothor.models.gmlvq import generate_model as generate_gmlvq from protothor.nn.collectable_loss import EvaluateExposedLosses from protothor.nn.container import find_instances logging.basicConfig(level=logging.INFO) def main(): # 1 - Get Dataset data = Iris([0, 2]) train_ds, val_ds = random_split(data, [0.7, 0.3]) standardize = Standardizer(train_ds) # 2 - Create Dataloaders train_loader = DataLoader( standardize(train_ds), shuffle=True, batch_size=len(train_ds), ) val_loader = DataLoader( standardize(val_ds), shuffle=False, batch_size=len(val_ds), ) # 3 - Initialize Prototypes labels, positions = zero_initialization(data) # 4 - Generate Torch Model model = generate_gmlvq(positions, labels) omega_layer = find_instances(model, torch.nn.Linear)[0] # 5 - Initialize Lightning Module module = MetricModule(model, EvaluateExposedLosses(model)) # 6 - Define Callbacks callbacks = [ OmegaTraceNormalization(omega_layer), ] callbacks.append( EarlyStopping( monitor="accuracy/validation", min_delta=0.0002, patience=50, mode="max", verbose=False, check_on_train_epoch_end=True, ) ) # 7 - Define Trainer trainer = L.Trainer(detect_anomaly=True, max_epochs=1000, callbacks=callbacks) # 8 - Train Model trainer.fit(module, train_loader, val_loader) # 9 - Analyse results omega = omega_layer.weight.detach() classification_correlation_matrix = omega.T @ omega logging.info(classification_correlation_matrix) if __name__ == "__main__": main()