80 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			80 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """Write a short description of your experiment here."""
 | |
| 
 | |
| import logging
 | |
| import lightning as L
 | |
| import torch
 | |
| from torch.utils.data import DataLoader, random_split
 | |
| 
 | |
| from protothor.callbacks.normalization import OmegaTraceNormalization
 | |
| from protothor.data.preprocessing import Standardizer
 | |
| from lightning.pytorch.callbacks import EarlyStopping
 | |
| from protothor.data.toy import Iris
 | |
| from protothor.functional.initialization import zero_initialization
 | |
| from protothor.lightning.metric_module import MetricModule
 | |
| from protothor.models.gmlvq import generate_model as generate_gmlvq
 | |
| from protothor.nn.collectable_loss import EvaluateExposedLosses
 | |
| from protothor.nn.container import find_instances
 | |
| 
 | |
# Route records of INFO and above (e.g. the matrix logged at the end of main())
# to stderr via the root logger's default handler.
logging.basicConfig(level="INFO")
 | |
| 
 | |
| 
 | |
def main(*, seed=None, max_epochs: int = 1000) -> None:
    """Train a GMLVQ model on Iris data with Lightning and log Omega^T Omega.

    Steps: split the data 70/30 into train/validation, standardize both
    splits with a ``Standardizer`` built from the training split, create
    full-batch data loaders, build a GMLVQ model from prototypes produced by
    ``zero_initialization``, train it with ``OmegaTraceNormalization`` and
    early stopping on validation accuracy, then log ``omega.T @ omega``
    computed from the model's Omega (Linear) layer.

    Args:
        seed: Optional integer seed. When given, ``L.seed_everything(seed)``
            is called before the train/validation split. This makes the
            split, DataLoader shuffling and any RNG-based model
            initialization reproducible. ``None`` (default) keeps the
            original unseeded behavior.
        max_epochs: Maximum number of training epochs. Early stopping may
            end training sooner. Defaults to 1000.

    Raises:
        RuntimeError: If the generated model contains no ``torch.nn.Linear``
            layer to use as the Omega projection.
    """
    logger = logging.getLogger(__name__)

    # 0 - Optional seeding. This must happen before random_split so the split
    # itself draws from the seeded RNG.
    if seed is not None:
        L.seed_everything(seed)

    # 1 - Get Dataset
    # NOTE(review): [0, 2] presumably selects a subset of Iris features or
    # classes. Confirm against protothor.data.toy.Iris.
    data = Iris([0, 2])
    train_ds, val_ds = random_split(data, [0.7, 0.3])

    # Built from the training split only and applied to both splits. This
    # assumes Standardizer fits its statistics on the dataset it is
    # constructed with, so validation data does not leak into preprocessing.
    standardize = Standardizer(train_ds)

    # 2 - Create Dataloaders (full-batch: each split is a single batch)
    train_loader = DataLoader(
        standardize(train_ds),
        shuffle=True,
        batch_size=len(train_ds),
    )
    val_loader = DataLoader(
        standardize(val_ds),
        shuffle=False,
        batch_size=len(val_ds),
    )

    # 3 - Initialize Prototypes
    labels, positions = zero_initialization(data)

    # 4 - Generate Torch Model
    model = generate_gmlvq(positions, labels)
    linear_layers = find_instances(model, torch.nn.Linear)
    if not linear_layers:
        raise RuntimeError(
            "generate_gmlvq returned a model without a torch.nn.Linear "
            "(Omega) layer; cannot apply OmegaTraceNormalization."
        )
    # The first Linear layer found is assumed to be the Omega projection.
    omega_layer = linear_layers[0]

    # 5 - Initialize Lightning Module
    module = MetricModule(model, EvaluateExposedLosses(model))

    # 6 - Define Callbacks
    callbacks = [
        OmegaTraceNormalization(omega_layer),
        EarlyStopping(
            monitor="accuracy/validation",
            min_delta=0.0002,
            patience=50,
            mode="max",
            verbose=False,
            check_on_train_epoch_end=True,
        ),
    ]

    # 7 - Define Trainer
    # detect_anomaly=True turns on autograd anomaly detection. It surfaces
    # NaN-producing backward ops but slows down every backward pass.
    trainer = L.Trainer(
        detect_anomaly=True,
        max_epochs=max_epochs,
        callbacks=callbacks,
    )

    # 8 - Train Model
    trainer.fit(module, train_loader, val_loader)

    # 9 - Analyse results
    # torch.nn.Linear.weight has shape (out_features, in_features), so
    # omega.T @ omega is an (in_features x in_features) matrix.
    omega = omega_layer.weight.detach()
    classification_correlation_matrix = omega.T @ omega
    logger.info(
        "Classification correlation matrix (Omega^T Omega):\n%s",
        classification_correlation_matrix,
    )
 | |
| 
 | |
| 
 | |
# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
 |