141 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			141 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from dataclasses import dataclass, field
 | |
| from typing import Callable, Type
 | |
| 
 | |
| import torch
 | |
| from prototorch.core.competitions import WTAC
 | |
| from prototorch.core.components import LabeledComponents
 | |
| from prototorch.core.distances import euclidean_distance
 | |
| from prototorch.core.initializers import (
 | |
|     AbstractComponentsInitializer,
 | |
|     LabelsInitializer,
 | |
| )
 | |
| from prototorch.core.losses import GLVQLoss
 | |
| from prototorch.models.proto_y_architecture.base import BaseYArchitecture
 | |
| from prototorch.nn.wrappers import LambdaLayer
 | |
| 
 | |
| 
 | |
class SupervisedArchitecture(BaseYArchitecture):
    """Architecture mixin that equips a Y-architecture with labeled prototypes.

    Owns a ``LabeledComponents`` layer whose per-class distribution and
    position initializer are configured via :class:`HyperParameters`.
    """

    components_layer: LabeledComponents

    @dataclass
    class HyperParameters:
        # Number of prototypes per class (mapping class identifier -> count).
        distribution: dict[str, int]
        # Initializer producing the initial prototype positions.
        component_initializer: AbstractComponentsInitializer

    def init_components(self, hparams: HyperParameters):
        """Build the labeled prototype layer from the given hyperparameters."""
        layer = LabeledComponents(
            distribution=hparams.distribution,
            components_initializer=hparams.component_initializer,
            labels_initializer=LabelsInitializer(),
        )
        self.components_layer = layer

    @property
    def prototypes(self):
        """Prototype positions, detached and moved to CPU."""
        return self.components_layer.components.detach().cpu()

    @property
    def prototype_labels(self):
        """Prototype class labels, detached and moved to CPU."""
        return self.components_layer.labels.detach().cpu()
 | |
| 
 | |
| 
 | |
class WTACompetitionMixin(BaseYArchitecture):
    """Inference mixin using winner-take-all competition over the measures."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # WTA competition needs no extra hyperparameters.
        pass

    def init_inference(self, hparams: HyperParameters):
        """Instantiate the winner-take-all competition layer."""
        self.competition_layer = WTAC()

    def inference(self, comparison_measures, components):
        """Return predictions by WTA over ``comparison_measures``.

        ``components`` is indexed at 1 for the prototype labels — presumably
        a (tensor, labels) pair; verify against the components layer.
        """
        labels = components[1]
        return self.competition_layer(comparison_measures, labels)
 | |
| 
 | |
| 
 | |
class GLVQLossMixin(BaseYArchitecture):
    """Loss mixin computing the GLVQ loss with a configurable transfer function."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # Margin of the GLVQ loss; 0.0 is the classic formulation.
        margin: float = 0.0

        # Name of the transfer function plus extra keyword arguments that are
        # forwarded verbatim to GLVQLoss (here: sigmoid steepness beta).
        transfer_fn: str = "sigmoid_beta"
        transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))

    def init_loss(self, hparams: HyperParameters):
        """Create the GLVQ loss layer from the hyperparameters."""
        self.loss_layer = GLVQLoss(
            margin=hparams.margin,
            transfer_fn=hparams.transfer_fn,
            **hparams.transfer_args,
        )

    def loss(self, comparison_measures, batch, components):
        """Compute the GLVQ loss for one batch, log it, and return it.

        ``batch`` and ``components`` are indexed at 1 for targets and
        prototype labels respectively — presumably (data, labels) pairs.
        """
        targets = batch[1]
        prototype_labels = components[1]
        loss_value = self.loss_layer(comparison_measures, targets,
                                     prototype_labels)
        self.log('loss', loss_value)
        return loss_value
 | |
| 
 | |
| 
 | |
class SingleLearningRateMixin(BaseYArchitecture):
    """Optimization mixin that trains every parameter with one learning rate."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # Training Hyperparameters
        lr: float = 0.01
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

    def __init__(self, hparams: HyperParameters) -> None:
        """Store learning rate and optimizer class for later construction."""
        super().__init__(hparams)
        self.optimizer = hparams.optimizer
        self.lr = hparams.lr

    def configure_optimizers(self):
        """Build the optimizer over all model parameters (Lightning hook)."""
        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
 | |
| 
 | |
| 
 | |
class SimpleComparisonMixin(BaseYArchitecture):
    """Comparison mixin applying a stateless distance function via LambdaLayer."""

    @dataclass
    class HyperParameters(BaseYArchitecture.HyperParameters):
        # Training Hyperparameters
        comparison_fn: Callable = euclidean_distance
        comparison_args: dict = field(default_factory=dict)

    def init_comparison(self, hparams: HyperParameters):
        """Wrap the configured comparison function in a LambdaLayer."""
        self.comparison_layer = LambdaLayer(
            fn=hparams.comparison_fn,
            **hparams.comparison_args,
        )

    def comparison(self, batch, components):
        """Return pairwise comparison measures between batch and prototypes."""
        batch_tensor, _ = batch
        comp_tensor, _ = components

        # unsqueeze(1) inserts a singleton axis on the prototype tensor before
        # the comparison — presumably for broadcasting; TODO confirm shapes.
        return self.comparison_layer(batch_tensor, comp_tensor.unsqueeze(1))
 | |
| 
 | |
| 
 | |
| # ##############################################################################
 | |
| # GLVQ
 | |
| # ##############################################################################
 | |
class GLVQ(
        SupervisedArchitecture,
        SimpleComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
        SingleLearningRateMixin,
):
    """Generalized Learning Vector Quantization in the Y-architecture scheme.

    Composed from labeled prototypes, a Euclidean comparison layer, the GLVQ
    loss, winner-take-all inference, and a single-learning-rate optimizer.
    """

    @dataclass
    class HyperParameters(
            SimpleComparisonMixin.HyperParameters,
            SingleLearningRateMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        """Union of the hyperparameters of all composed mixins."""
 |