Compare commits
No commits in common. "c5dd194bb362eec75c12a6fe23838b85bd02147d" and "39486567b241f5e3558685dc81242e93fca2d4bf" have entirely different histories.
c5dd194bb3
...
39486567b2
6
hello.py
Normal file
6
hello.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
def main():
|
||||||
|
print("Hello from experiment!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
79
main.py
79
main.py
@ -1,79 +0,0 @@
|
|||||||
"""Write a short description of your experiment here."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import lightning as L
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import DataLoader, random_split
|
|
||||||
|
|
||||||
from protothor.callbacks.normalization import OmegaTraceNormalization
|
|
||||||
from protothor.data.preprocessing import Standardizer
|
|
||||||
from lightning.pytorch.callbacks import EarlyStopping
|
|
||||||
from protothor.data.toy import Iris
|
|
||||||
from protothor.functional.initialization import zero_initialization
|
|
||||||
from protothor.lightning.metric_module import MetricModule
|
|
||||||
from protothor.models.gmlvq import generate_model as generate_gmlvq
|
|
||||||
from protothor.nn.collectable_loss import EvaluateExposedLosses
|
|
||||||
from protothor.nn.container import find_instances
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
# 1 - Get Dataset
|
|
||||||
data = Iris([0, 2])
|
|
||||||
train_ds, val_ds = random_split(data, [0.7, 0.3])
|
|
||||||
|
|
||||||
standardize = Standardizer(train_ds)
|
|
||||||
|
|
||||||
# 2 - Create Dataloaders
|
|
||||||
train_loader = DataLoader(
|
|
||||||
standardize(train_ds),
|
|
||||||
shuffle=True,
|
|
||||||
batch_size=len(train_ds),
|
|
||||||
)
|
|
||||||
val_loader = DataLoader(
|
|
||||||
standardize(val_ds),
|
|
||||||
shuffle=False,
|
|
||||||
batch_size=len(val_ds),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3 - Initialize Prototypes
|
|
||||||
labels, positions = zero_initialization(data)
|
|
||||||
|
|
||||||
# 4 - Generate Torch Model
|
|
||||||
model = generate_gmlvq(positions, labels)
|
|
||||||
omega_layer = find_instances(model, torch.nn.Linear)[0]
|
|
||||||
|
|
||||||
# 5 - Initialize Lightning Module
|
|
||||||
module = MetricModule(model, EvaluateExposedLosses(model))
|
|
||||||
|
|
||||||
# 6 - Define Callbacks
|
|
||||||
|
|
||||||
callbacks = [
|
|
||||||
OmegaTraceNormalization(omega_layer),
|
|
||||||
]
|
|
||||||
callbacks.append(
|
|
||||||
EarlyStopping(
|
|
||||||
monitor="accuracy/validation",
|
|
||||||
min_delta=0.0002,
|
|
||||||
patience=50,
|
|
||||||
mode="max",
|
|
||||||
verbose=False,
|
|
||||||
check_on_train_epoch_end=True,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 7 - Define Trainer
|
|
||||||
trainer = L.Trainer(detect_anomaly=True, max_epochs=1000, callbacks=callbacks)
|
|
||||||
|
|
||||||
# 8 - Train Model
|
|
||||||
trainer.fit(module, train_loader, val_loader)
|
|
||||||
|
|
||||||
# 9 - Analyse results
|
|
||||||
omega = omega_layer.weight.detach()
|
|
||||||
classification_correlation_matrix = omega.T @ omega
|
|
||||||
logging.info(classification_correlation_matrix)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
@ -4,26 +4,4 @@ version = "0.1.0"
|
|||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = []
|
||||||
"protothor>=0.1.1",
|
|
||||||
"protothor-datarecords>=0.8.1",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
|
||||||
protothor = [{index = "protothor-gitlab"}]
|
|
||||||
protothor-datarecords = [{index = "protothor-datarecords-gitlab"}]
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "protothor-datarecords-gitlab"
|
|
||||||
url = "https://git.hs-mittweida.de/api/v4/projects/2428/packages/pypi/simple"
|
|
||||||
default = false
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "protothor-gitlab"
|
|
||||||
url = "https://git.hs-mittweida.de/api/v4/projects/2346/packages/pypi/simple"
|
|
||||||
default = false
|
|
||||||
|
|
||||||
[[tool.uv.index]]
|
|
||||||
name = "torch-cpu"
|
|
||||||
url = "https://download.pytorch.org/whl/cpu"
|
|
||||||
default = false
|
|
||||||
|
Loading…
Reference in New Issue
Block a user