Compare commits
	
		
			2 Commits
		
	
	
		
			39486567b2
			...
			c5dd194bb3
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| c5dd194bb3 | |||
| f103e1b00a | 
							
								
								
									
										6
									
								
								hello.py
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								hello.py
									
									
									
									
									
								
							| @@ -1,6 +0,0 @@ | |||||||
| def main(): |  | ||||||
|     print("Hello from experiment!") |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": |  | ||||||
|     main() |  | ||||||
							
								
								
									
										79
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | |||||||
|  | """Write a short description of your experiment here.""" | ||||||
|  |  | ||||||
|  | import logging | ||||||
|  | import lightning as L | ||||||
|  | import torch | ||||||
|  | from torch.utils.data import DataLoader, random_split | ||||||
|  |  | ||||||
|  | from protothor.callbacks.normalization import OmegaTraceNormalization | ||||||
|  | from protothor.data.preprocessing import Standardizer | ||||||
|  | from lightning.pytorch.callbacks import EarlyStopping | ||||||
|  | from protothor.data.toy import Iris | ||||||
|  | from protothor.functional.initialization import zero_initialization | ||||||
|  | from protothor.lightning.metric_module import MetricModule | ||||||
|  | from protothor.models.gmlvq import generate_model as generate_gmlvq | ||||||
|  | from protothor.nn.collectable_loss import EvaluateExposedLosses | ||||||
|  | from protothor.nn.container import find_instances | ||||||
|  |  | ||||||
|  | logging.basicConfig(level=logging.INFO) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     # 1 - Get Dataset | ||||||
|  |     data = Iris([0, 2]) | ||||||
|  |     train_ds, val_ds = random_split(data, [0.7, 0.3]) | ||||||
|  |  | ||||||
|  |     standardize = Standardizer(train_ds) | ||||||
|  |  | ||||||
|  |     # 2 - Create Dataloaders | ||||||
|  |     train_loader = DataLoader( | ||||||
|  |         standardize(train_ds), | ||||||
|  |         shuffle=True, | ||||||
|  |         batch_size=len(train_ds), | ||||||
|  |     ) | ||||||
|  |     val_loader = DataLoader( | ||||||
|  |         standardize(val_ds), | ||||||
|  |         shuffle=False, | ||||||
|  |         batch_size=len(val_ds), | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # 3 - Initialize Prototypes | ||||||
|  |     labels, positions = zero_initialization(data) | ||||||
|  |  | ||||||
|  |     # 4 - Generate Torch Model | ||||||
|  |     model = generate_gmlvq(positions, labels) | ||||||
|  |     omega_layer = find_instances(model, torch.nn.Linear)[0] | ||||||
|  |  | ||||||
|  |     # 5 - Initialize Lightning Module | ||||||
|  |     module = MetricModule(model, EvaluateExposedLosses(model)) | ||||||
|  |  | ||||||
|  |     # 6 - Define Callbacks | ||||||
|  |  | ||||||
|  |     callbacks = [ | ||||||
|  |         OmegaTraceNormalization(omega_layer), | ||||||
|  |     ] | ||||||
|  |     callbacks.append( | ||||||
|  |         EarlyStopping( | ||||||
|  |             monitor="accuracy/validation", | ||||||
|  |             min_delta=0.0002, | ||||||
|  |             patience=50, | ||||||
|  |             mode="max", | ||||||
|  |             verbose=False, | ||||||
|  |             check_on_train_epoch_end=True, | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # 7 - Define Trainer | ||||||
|  |     trainer = L.Trainer(detect_anomaly=True, max_epochs=1000, callbacks=callbacks) | ||||||
|  |  | ||||||
|  |     # 8 - Train Model | ||||||
|  |     trainer.fit(module, train_loader, val_loader) | ||||||
|  |  | ||||||
|  |     # 9 - Analyse results | ||||||
|  |     omega = omega_layer.weight.detach() | ||||||
|  |     classification_correlation_matrix = omega.T @ omega | ||||||
|  |     logging.info(classification_correlation_matrix) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     main() | ||||||
| @@ -4,4 +4,26 @@ version = "0.1.0" | |||||||
| description = "Add your description here" | description = "Add your description here" | ||||||
| readme = "README.md" | readme = "README.md" | ||||||
| requires-python = ">=3.12" | requires-python = ">=3.12" | ||||||
| dependencies = [] | dependencies = [ | ||||||
|  |     "protothor>=0.1.1", | ||||||
|  |     "protothor-datarecords>=0.8.1", | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | [tool.uv.sources] | ||||||
|  | protothor = [{index = "protothor-gitlab"}] | ||||||
|  | protothor-datarecords = [{index = "protothor-datarecords-gitlab"}] | ||||||
|  |  | ||||||
|  | [[tool.uv.index]] | ||||||
|  | name = "protothor-datarecords-gitlab" | ||||||
|  | url = "https://git.hs-mittweida.de/api/v4/projects/2428/packages/pypi/simple" | ||||||
|  | default = false | ||||||
|  |  | ||||||
|  | [[tool.uv.index]] | ||||||
|  | name = "protothor-gitlab" | ||||||
|  | url = "https://git.hs-mittweida.de/api/v4/projects/2346/packages/pypi/simple" | ||||||
|  | default = false | ||||||
|  |  | ||||||
|  | [[tool.uv.index]] | ||||||
|  | name = "torch-cpu" | ||||||
|  | url = "https://download.pytorch.org/whl/cpu" | ||||||
|  | default = false | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user