feat: add GMLVQ with new architecture
This commit is contained in:
		@@ -2,11 +2,12 @@ import prototorch as pt
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
from prototorch.core import SMCI
 | 
					from prototorch.core import SMCI
 | 
				
			||||||
from prototorch.models.proto_y_architecture.callbacks import (
 | 
					from prototorch.models.y_arch.callbacks import (
 | 
				
			||||||
    LogTorchmetricCallback,
 | 
					    LogTorchmetricCallback,
 | 
				
			||||||
    VisGLVQ2D,
 | 
					    PlotLambdaMatrixToTensorboard,
 | 
				
			||||||
 | 
					    VisGMLVQ2D,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from prototorch.models.proto_y_architecture.glvq import GLVQ
 | 
					from prototorch.models.y_arch.library.gmlvq import GMLVQ
 | 
				
			||||||
from pytorch_lightning.callbacks import EarlyStopping
 | 
					from pytorch_lightning.callbacks import EarlyStopping
 | 
				
			||||||
from torch.utils.data import DataLoader
 | 
					from torch.utils.data import DataLoader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -19,8 +20,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # ------------------------------------------------------------
 | 
					    # ------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataset
 | 
					    # Dataset
 | 
				
			||||||
    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
					    train_ds = pt.datasets.Iris()
 | 
				
			||||||
    train_ds.targets[train_ds.targets == 2.0] = 1.0
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Dataloader
 | 
					    # Dataloader
 | 
				
			||||||
    train_loader = DataLoader(
 | 
					    train_loader = DataLoader(
 | 
				
			||||||
@@ -38,17 +38,19 @@ if __name__ == "__main__":
 | 
				
			|||||||
    components_initializer = SMCI(train_ds)
 | 
					    components_initializer = SMCI(train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Define Hyperparameters
 | 
					    # Define Hyperparameters
 | 
				
			||||||
    hyperparameters = GLVQ.HyperParameters(
 | 
					    hyperparameters = GMLVQ.HyperParameters(
 | 
				
			||||||
        lr=0.1,
 | 
					        lr=0.1,
 | 
				
			||||||
 | 
					        backbone_lr=5,
 | 
				
			||||||
 | 
					        input_dim=4,
 | 
				
			||||||
        distribution=dict(
 | 
					        distribution=dict(
 | 
				
			||||||
            num_classes=2,
 | 
					            num_classes=3,
 | 
				
			||||||
            per_class=1,
 | 
					            per_class=1,
 | 
				
			||||||
        ),
 | 
					        ),
 | 
				
			||||||
        component_initializer=components_initializer,
 | 
					        component_initializer=components_initializer,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Create Model
 | 
					    # Create Model
 | 
				
			||||||
    model = GLVQ(hyperparameters)
 | 
					    model = GMLVQ(hyperparameters)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print(model)
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -60,19 +62,17 @@ if __name__ == "__main__":
 | 
				
			|||||||
    stopping_criterion = LogTorchmetricCallback(
 | 
					    stopping_criterion = LogTorchmetricCallback(
 | 
				
			||||||
        'recall',
 | 
					        'recall',
 | 
				
			||||||
        torchmetrics.Recall,
 | 
					        torchmetrics.Recall,
 | 
				
			||||||
        num_classes=2,
 | 
					        num_classes=3,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    es = EarlyStopping(
 | 
					    es = EarlyStopping(
 | 
				
			||||||
        monitor=stopping_criterion.name,
 | 
					        monitor=stopping_criterion.name,
 | 
				
			||||||
        min_delta=0.001,
 | 
					 | 
				
			||||||
        patience=15,
 | 
					 | 
				
			||||||
        mode="max",
 | 
					        mode="max",
 | 
				
			||||||
        check_on_train_epoch_end=True,
 | 
					        patience=10,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Visualization Callback
 | 
					    # Visualization Callback
 | 
				
			||||||
    vis = VisGLVQ2D(data=train_ds)
 | 
					    vis = VisGMLVQ2D(data=train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Define trainer
 | 
					    # Define trainer
 | 
				
			||||||
    trainer = pl.Trainer(
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
@@ -80,10 +80,9 @@ if __name__ == "__main__":
 | 
				
			|||||||
            vis,
 | 
					            vis,
 | 
				
			||||||
            stopping_criterion,
 | 
					            stopping_criterion,
 | 
				
			||||||
            es,
 | 
					            es,
 | 
				
			||||||
 | 
					            PlotLambdaMatrixToTensorboard(),
 | 
				
			||||||
        ],
 | 
					        ],
 | 
				
			||||||
        gpus=0,
 | 
					        max_epochs=1000,
 | 
				
			||||||
        max_epochs=200,
 | 
					 | 
				
			||||||
        log_every_n_steps=1,
 | 
					 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Train
 | 
					    # Train
 | 
				
			||||||
@@ -1,63 +0,0 @@
 | 
				
			|||||||
from typing import Optional, Type
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import numpy as np
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
import torchmetrics
 | 
					 | 
				
			||||||
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
 | 
					 | 
				
			||||||
from prototorch.models.vis import Vis2DAbstract
 | 
					 | 
				
			||||||
from prototorch.utils.utils import mesh2d
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class LogTorchmetricCallback(pl.Callback):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        name,
 | 
					 | 
				
			||||||
        metric: Type[torchmetrics.Metric],
 | 
					 | 
				
			||||||
        on="prediction",
 | 
					 | 
				
			||||||
        **metric_kwargs,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        self.name = name
 | 
					 | 
				
			||||||
        self.metric = metric
 | 
					 | 
				
			||||||
        self.metric_kwargs = metric_kwargs
 | 
					 | 
				
			||||||
        self.on = on
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setup(
 | 
					 | 
				
			||||||
        self,
 | 
					 | 
				
			||||||
        trainer: pl.Trainer,
 | 
					 | 
				
			||||||
        pl_module: BaseYArchitecture,
 | 
					 | 
				
			||||||
        stage: Optional[str] = None,
 | 
					 | 
				
			||||||
    ) -> None:
 | 
					 | 
				
			||||||
        if self.on == "prediction":
 | 
					 | 
				
			||||||
            pl_module.register_torchmetric(
 | 
					 | 
				
			||||||
                self.name,
 | 
					 | 
				
			||||||
                self.metric,
 | 
					 | 
				
			||||||
                **self.metric_kwargs,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            raise ValueError(f"{self.on} is no valid metric hook")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class VisGLVQ2D(Vis2DAbstract):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def visualize(self, pl_module):
 | 
					 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					 | 
				
			||||||
        plabels = pl_module.prototype_labels
 | 
					 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					 | 
				
			||||||
        ax = self.setup_ax()
 | 
					 | 
				
			||||||
        self.plot_protos(ax, protos, plabels)
 | 
					 | 
				
			||||||
        if x_train is not None:
 | 
					 | 
				
			||||||
            self.plot_data(ax, x_train, y_train)
 | 
					 | 
				
			||||||
            mesh_input, xx, yy = mesh2d(
 | 
					 | 
				
			||||||
                np.vstack([x_train, protos]),
 | 
					 | 
				
			||||||
                self.border,
 | 
					 | 
				
			||||||
                self.resolution,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
 | 
					 | 
				
			||||||
        _components = pl_module.components_layer.components
 | 
					 | 
				
			||||||
        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
 | 
					 | 
				
			||||||
        y_pred = pl_module.predict(mesh_input)
 | 
					 | 
				
			||||||
        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					 | 
				
			||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					 | 
				
			||||||
@@ -1,140 +0,0 @@
 | 
				
			|||||||
from dataclasses import dataclass, field
 | 
					 | 
				
			||||||
from typing import Callable, Type
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import torch
 | 
					 | 
				
			||||||
from prototorch.core.competitions import WTAC
 | 
					 | 
				
			||||||
from prototorch.core.components import LabeledComponents
 | 
					 | 
				
			||||||
from prototorch.core.distances import euclidean_distance
 | 
					 | 
				
			||||||
from prototorch.core.initializers import (
 | 
					 | 
				
			||||||
    AbstractComponentsInitializer,
 | 
					 | 
				
			||||||
    LabelsInitializer,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from prototorch.core.losses import GLVQLoss
 | 
					 | 
				
			||||||
from prototorch.models.proto_y_architecture.base import BaseYArchitecture
 | 
					 | 
				
			||||||
from prototorch.nn.wrappers import LambdaLayer
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SupervisedArchitecture(BaseYArchitecture):
 | 
					 | 
				
			||||||
    components_layer: LabeledComponents
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @dataclass
 | 
					 | 
				
			||||||
    class HyperParameters:
 | 
					 | 
				
			||||||
        distribution: dict[str, int]
 | 
					 | 
				
			||||||
        component_initializer: AbstractComponentsInitializer
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def init_components(self, hparams: HyperParameters):
 | 
					 | 
				
			||||||
        self.components_layer = LabeledComponents(
 | 
					 | 
				
			||||||
            distribution=hparams.distribution,
 | 
					 | 
				
			||||||
            components_initializer=hparams.component_initializer,
 | 
					 | 
				
			||||||
            labels_initializer=LabelsInitializer(),
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def prototypes(self):
 | 
					 | 
				
			||||||
        return self.components_layer.components.detach().cpu()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def prototype_labels(self):
 | 
					 | 
				
			||||||
        return self.components_layer.labels.detach().cpu()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class WTACompetitionMixin(BaseYArchitecture):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @dataclass
 | 
					 | 
				
			||||||
    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def init_inference(self, hparams: HyperParameters):
 | 
					 | 
				
			||||||
        self.competition_layer = WTAC()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def inference(self, comparison_measures, components):
 | 
					 | 
				
			||||||
        comp_labels = components[1]
 | 
					 | 
				
			||||||
        return self.competition_layer(comparison_measures, comp_labels)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class GLVQLossMixin(BaseYArchitecture):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @dataclass
 | 
					 | 
				
			||||||
    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
					 | 
				
			||||||
        margin: float = 0.0
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        transfer_fn: str = "sigmoid_beta"
 | 
					 | 
				
			||||||
        transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def init_loss(self, hparams: HyperParameters):
 | 
					 | 
				
			||||||
        self.loss_layer = GLVQLoss(
 | 
					 | 
				
			||||||
            margin=hparams.margin,
 | 
					 | 
				
			||||||
            transfer_fn=hparams.transfer_fn,
 | 
					 | 
				
			||||||
            **hparams.transfer_args,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def loss(self, comparison_measures, batch, components):
 | 
					 | 
				
			||||||
        target = batch[1]
 | 
					 | 
				
			||||||
        comp_labels = components[1]
 | 
					 | 
				
			||||||
        loss = self.loss_layer(comparison_measures, target, comp_labels)
 | 
					 | 
				
			||||||
        self.log('loss', loss)
 | 
					 | 
				
			||||||
        return loss
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SingleLearningRateMixin(BaseYArchitecture):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @dataclass
 | 
					 | 
				
			||||||
    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
					 | 
				
			||||||
        # Training Hyperparameters
 | 
					 | 
				
			||||||
        lr: float = 0.01
 | 
					 | 
				
			||||||
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, hparams: HyperParameters) -> None:
 | 
					 | 
				
			||||||
        super().__init__(hparams)
 | 
					 | 
				
			||||||
        self.lr = hparams.lr
 | 
					 | 
				
			||||||
        self.optimizer = hparams.optimizer
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def configure_optimizers(self):
 | 
					 | 
				
			||||||
        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class SimpleComparisonMixin(BaseYArchitecture):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @dataclass
 | 
					 | 
				
			||||||
    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
					 | 
				
			||||||
        # Training Hyperparameters
 | 
					 | 
				
			||||||
        comparison_fn: Callable = euclidean_distance
 | 
					 | 
				
			||||||
        comparison_args: dict = field(default_factory=lambda: dict())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def init_comparison(self, hparams: HyperParameters):
 | 
					 | 
				
			||||||
        self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
 | 
					 | 
				
			||||||
                                            **hparams.comparison_args)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def comparison(self, batch, components):
 | 
					 | 
				
			||||||
        comp_tensor, _ = components
 | 
					 | 
				
			||||||
        batch_tensor, _ = batch
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        comp_tensor = comp_tensor.unsqueeze(1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        distances = self.comparison_layer(batch_tensor, comp_tensor)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        return distances
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
# ##############################################################################
 | 
					 | 
				
			||||||
# GLVQ
 | 
					 | 
				
			||||||
# ##############################################################################
 | 
					 | 
				
			||||||
class GLVQ(
 | 
					 | 
				
			||||||
        SupervisedArchitecture,
 | 
					 | 
				
			||||||
        SimpleComparisonMixin,
 | 
					 | 
				
			||||||
        GLVQLossMixin,
 | 
					 | 
				
			||||||
        WTACompetitionMixin,
 | 
					 | 
				
			||||||
        SingleLearningRateMixin,
 | 
					 | 
				
			||||||
):
 | 
					 | 
				
			||||||
    """GLVQ using the new Scheme
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @dataclass
 | 
					 | 
				
			||||||
    class HyperParameters(
 | 
					 | 
				
			||||||
            SimpleComparisonMixin.HyperParameters,
 | 
					 | 
				
			||||||
            SingleLearningRateMixin.HyperParameters,
 | 
					 | 
				
			||||||
            GLVQLossMixin.HyperParameters,
 | 
					 | 
				
			||||||
            WTACompetitionMixin.HyperParameters,
 | 
					 | 
				
			||||||
            SupervisedArchitecture.HyperParameters,
 | 
					 | 
				
			||||||
    ):
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
							
								
								
									
										15
									
								
								prototorch/models/y_arch/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								prototorch/models/y_arch/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,15 @@
 | 
				
			|||||||
 | 
					from .architectures.base import BaseYArchitecture
 | 
				
			||||||
 | 
					from .architectures.comparison import SimpleComparisonMixin
 | 
				
			||||||
 | 
					from .architectures.competition import WTACompetitionMixin
 | 
				
			||||||
 | 
					from .architectures.components import SupervisedArchitecture
 | 
				
			||||||
 | 
					from .architectures.loss import GLVQLossMixin
 | 
				
			||||||
 | 
					from .architectures.optimization import SingleLearningRateMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    'BaseYArchitecture',
 | 
				
			||||||
 | 
					    "SimpleComparisonMixin",
 | 
				
			||||||
 | 
					    "SingleLearningRateMixin",
 | 
				
			||||||
 | 
					    "SupervisedArchitecture",
 | 
				
			||||||
 | 
					    "WTACompetitionMixin",
 | 
				
			||||||
 | 
					    "GLVQLossMixin",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
@@ -1,12 +1,7 @@
 | 
				
			|||||||
"""
 | 
					"""
 | 
				
			||||||
CLCC Scheme
 | 
					Proto Y Architecture
 | 
				
			||||||
 | 
					 | 
				
			||||||
CLCC is a LVQ scheme containing 4 steps
 | 
					 | 
				
			||||||
- Components
 | 
					 | 
				
			||||||
- Latent Space
 | 
					 | 
				
			||||||
- Comparison
 | 
					 | 
				
			||||||
- Competition
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Network architecture for Component based Learning.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
							
								
								
									
										41
									
								
								prototorch/models/y_arch/architectures/comparison.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										41
									
								
								prototorch/models/y_arch/architectures/comparison.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,41 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass, field
 | 
				
			||||||
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.core.distances import euclidean_distance
 | 
				
			||||||
 | 
					from prototorch.models.y_arch.architectures.base import BaseYArchitecture
 | 
				
			||||||
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SimpleComparisonMixin(BaseYArchitecture):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Simple Comparison
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    A comparison layer that only uses the positions of the components and the batch for dissimilarity computation.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # HyperParameters
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
 | 
				
			||||||
 | 
					        comparison_args: Keyword arguments for the comparison function. Default: {}.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        comparison_fn: Callable = euclidean_distance
 | 
				
			||||||
 | 
					        comparison_args: dict = field(default_factory=lambda: dict())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    def init_comparison(self, hparams: HyperParameters):
 | 
				
			||||||
 | 
					        self.comparison_layer = LambdaLayer(fn=hparams.comparison_fn,
 | 
				
			||||||
 | 
					                                            **hparams.comparison_args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def comparison(self, batch, components):
 | 
				
			||||||
 | 
					        comp_tensor, _ = components
 | 
				
			||||||
 | 
					        batch_tensor, _ = batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        comp_tensor = comp_tensor.unsqueeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        distances = self.comparison_layer(batch_tensor, comp_tensor)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return distances
 | 
				
			||||||
							
								
								
									
										29
									
								
								prototorch/models/y_arch/architectures/competition.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								prototorch/models/y_arch/architectures/competition.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,29 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.core.competitions import WTAC
 | 
				
			||||||
 | 
					from prototorch.models.y_arch.architectures.base import BaseYArchitecture
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class WTACompetitionMixin(BaseYArchitecture):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Winner Take All Competition
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    A competition layer that uses the winner-take-all strategy.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # HyperParameters
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        No hyperparameters.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    def init_inference(self, hparams: HyperParameters):
 | 
				
			||||||
 | 
					        self.competition_layer = WTAC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def inference(self, comparison_measures, components):
 | 
				
			||||||
 | 
					        comp_labels = components[1]
 | 
				
			||||||
 | 
					        return self.competition_layer(comparison_measures, comp_labels)
 | 
				
			||||||
							
								
								
									
										53
									
								
								prototorch/models/y_arch/architectures/components.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								prototorch/models/y_arch/architectures/components.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.core.components import LabeledComponents
 | 
				
			||||||
 | 
					from prototorch.core.initializers import (
 | 
				
			||||||
 | 
					    AbstractComponentsInitializer,
 | 
				
			||||||
 | 
					    LabelsInitializer,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from prototorch.models.y_arch import BaseYArchitecture
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SupervisedArchitecture(BaseYArchitecture):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Supervised Architecture
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    An architecture that uses labeled Components as component Layer.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    components_layer: LabeledComponents
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # HyperParameters
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class HyperParameters:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        distribution: A valid prototype distribution. No default possible.
 | 
				
			||||||
 | 
					        components_initializer: An implementation of AbstractComponentsInitializer. No default possible.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        distribution: "dict[str, int]"
 | 
				
			||||||
 | 
					        component_initializer: AbstractComponentsInitializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    def init_components(self, hparams: HyperParameters):
 | 
				
			||||||
 | 
					        self.components_layer = LabeledComponents(
 | 
				
			||||||
 | 
					            distribution=hparams.distribution,
 | 
				
			||||||
 | 
					            components_initializer=hparams.component_initializer,
 | 
				
			||||||
 | 
					            labels_initializer=LabelsInitializer(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Properties
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def prototypes(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns the position of the prototypes.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.components_layer.components.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def prototype_labels(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Returns the labels of the prototypes.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.components_layer.labels.detach().cpu()
 | 
				
			||||||
							
								
								
									
										42
									
								
								prototorch/models/y_arch/architectures/loss.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								prototorch/models/y_arch/architectures/loss.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,42 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass, field
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.core.losses import GLVQLoss
 | 
				
			||||||
 | 
					from prototorch.models.y_arch.architectures.base import BaseYArchitecture
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GLVQLossMixin(BaseYArchitecture):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    GLVQ Loss
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # HyperParameters
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        margin: The margin of the GLVQ loss. Default: 0.0.
 | 
				
			||||||
 | 
					        transfer_fn: Transfer function to use. Default: sigmoid_beta.
 | 
				
			||||||
 | 
					        transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        margin: float = 0.0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        transfer_fn: str = "sigmoid_beta"
 | 
				
			||||||
 | 
					        transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    def init_loss(self, hparams: HyperParameters):
 | 
				
			||||||
 | 
					        self.loss_layer = GLVQLoss(
 | 
				
			||||||
 | 
					            margin=hparams.margin,
 | 
				
			||||||
 | 
					            transfer_fn=hparams.transfer_fn,
 | 
				
			||||||
 | 
					            **hparams.transfer_args,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def loss(self, comparison_measures, batch, components):
 | 
				
			||||||
 | 
					        target = batch[1]
 | 
				
			||||||
 | 
					        comp_labels = components[1]
 | 
				
			||||||
 | 
					        loss = self.loss_layer(comparison_measures, target, comp_labels)
 | 
				
			||||||
 | 
					        self.log('loss', loss)
 | 
				
			||||||
 | 
					        return loss
 | 
				
			||||||
							
								
								
									
										36
									
								
								prototorch/models/y_arch/architectures/optimization.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								prototorch/models/y_arch/architectures/optimization.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from typing import Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.models.y_arch import BaseYArchitecture
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SingleLearningRateMixin(BaseYArchitecture):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Single Learning Rate
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    All parameters are updated with a single learning rate.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # HyperParameters
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class HyperParameters(BaseYArchitecture.HyperParameters):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        lr: The learning rate. Default: 0.1.
 | 
				
			||||||
 | 
					        optimizer: The optimizer to use. Default: torch.optim.Adam.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        lr: float = 0.1
 | 
				
			||||||
 | 
					        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    def __init__(self, hparams: HyperParameters) -> None:
 | 
				
			||||||
 | 
					        super().__init__(hparams)
 | 
				
			||||||
 | 
					        self.lr = hparams.lr
 | 
				
			||||||
 | 
					        self.optimizer = hparams.optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hooks
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    def configure_optimizers(self):
 | 
				
			||||||
 | 
					        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
 | 
				
			||||||
							
								
								
									
										149
									
								
								prototorch/models/y_arch/callbacks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										149
									
								
								prototorch/models/y_arch/callbacks.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,149 @@
 | 
				
			|||||||
 | 
					import warnings
 | 
				
			||||||
 | 
					from typing import Optional, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from prototorch.models.vis import Vis2DAbstract
 | 
				
			||||||
 | 
					from prototorch.models.y_arch.architectures.base import BaseYArchitecture
 | 
				
			||||||
 | 
					from prototorch.models.y_arch.library.gmlvq import GMLVQ
 | 
				
			||||||
 | 
					from prototorch.utils.utils import mesh2d
 | 
				
			||||||
 | 
					from pytorch_lightning.loggers import TensorBoardLogger
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					DIVERGING_COLOR_MAPS = [
 | 
				
			||||||
 | 
					    'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
 | 
				
			||||||
 | 
					    'Spectral', 'coolwarm', 'bwr', 'seismic'
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LogTorchmetricCallback(pl.Callback):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        name,
 | 
				
			||||||
 | 
					        metric: Type[torchmetrics.Metric],
 | 
				
			||||||
 | 
					        on="prediction",
 | 
				
			||||||
 | 
					        **metric_kwargs,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        self.name = name
 | 
				
			||||||
 | 
					        self.metric = metric
 | 
				
			||||||
 | 
					        self.metric_kwargs = metric_kwargs
 | 
				
			||||||
 | 
					        self.on = on
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setup(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        trainer: pl.Trainer,
 | 
				
			||||||
 | 
					        pl_module: BaseYArchitecture,
 | 
				
			||||||
 | 
					        stage: Optional[str] = None,
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
 | 
					        if self.on == "prediction":
 | 
				
			||||||
 | 
					            pl_module.register_torchmetric(
 | 
				
			||||||
 | 
					                self.name,
 | 
				
			||||||
 | 
					                self.metric,
 | 
				
			||||||
 | 
					                **self.metric_kwargs,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise ValueError(f"{self.on} is no valid metric hook")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VisGLVQ2D(Vis2DAbstract):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def visualize(self, pl_module):
 | 
				
			||||||
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
 | 
					        plabels = pl_module.prototype_labels
 | 
				
			||||||
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
 | 
					        self.plot_protos(ax, protos, plabels)
 | 
				
			||||||
 | 
					        if x_train is not None:
 | 
				
			||||||
 | 
					            self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
 | 
					            mesh_input, xx, yy = mesh2d(
 | 
				
			||||||
 | 
					                np.vstack([x_train, protos]),
 | 
				
			||||||
 | 
					                self.border,
 | 
				
			||||||
 | 
					                self.resolution,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
 | 
				
			||||||
 | 
					        _components = pl_module.components_layer.components
 | 
				
			||||||
 | 
					        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
 | 
				
			||||||
 | 
					        y_pred = pl_module.predict(mesh_input)
 | 
				
			||||||
 | 
					        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
				
			||||||
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VisGMLVQ2D(Vis2DAbstract):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, *args, ev_proj=True, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					        self.ev_proj = ev_proj
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def visualize(self, pl_module):
 | 
				
			||||||
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
 | 
					        plabels = pl_module.prototype_labels
 | 
				
			||||||
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
 | 
					        device = pl_module.device
 | 
				
			||||||
 | 
					        omega = pl_module._omega.detach()
 | 
				
			||||||
 | 
					        lam = omega @ omega.T
 | 
				
			||||||
 | 
					        u, _, _ = torch.pca_lowrank(lam, q=2)
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            x_train = torch.Tensor(x_train).to(device)
 | 
				
			||||||
 | 
					            x_train = x_train @ u
 | 
				
			||||||
 | 
					            x_train = x_train.cpu().detach()
 | 
				
			||||||
 | 
					        if self.show_protos:
 | 
				
			||||||
 | 
					            with torch.no_grad():
 | 
				
			||||||
 | 
					                protos = torch.Tensor(protos).to(device)
 | 
				
			||||||
 | 
					                protos = protos @ u
 | 
				
			||||||
 | 
					                protos = protos.cpu().detach()
 | 
				
			||||||
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
 | 
					        self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
 | 
					        if self.show_protos:
 | 
				
			||||||
 | 
					            self.plot_protos(ax, protos, plabels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class PlotLambdaMatrixToTensorboard(pl.Callback):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, cmap='seismic') -> None:
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.cmap = cmap
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str:
 | 
				
			||||||
 | 
					            warnings.warn(
 | 
				
			||||||
 | 
					                f"{self.cmap} is not a diverging color map. We recommend to use one of the following: {DIVERGING_COLOR_MAPS}"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_train_start(self, trainer, pl_module: GMLVQ):
 | 
				
			||||||
 | 
					        self.plot_lambda(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
 | 
				
			||||||
 | 
					        self.plot_lambda(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def plot_lambda(self, trainer, pl_module: GMLVQ):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.fig, self.ax = plt.subplots(1, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # plot lambda matrix
 | 
				
			||||||
 | 
					        l_matrix = pl_module.lambda_matrix
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # normalize lambda matrix
 | 
				
			||||||
 | 
					        l_matrix = l_matrix / torch.max(torch.abs(l_matrix))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # plot lambda matrix
 | 
				
			||||||
 | 
					        self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.fig.colorbar(self.ax.images[-1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # add title
 | 
				
			||||||
 | 
					        self.ax.set_title('Lambda Matrix')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # add to tensorboard
 | 
				
			||||||
 | 
					        if isinstance(trainer.logger, TensorBoardLogger):
 | 
				
			||||||
 | 
					            trainer.logger.experiment.add_figure(
 | 
				
			||||||
 | 
					                f"lambda_matrix",
 | 
				
			||||||
 | 
					                self.fig,
 | 
				
			||||||
 | 
					                trainer.global_step,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            warnings.warn(
 | 
				
			||||||
 | 
					                f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
							
								
								
									
										5
									
								
								prototorch/models/y_arch/library/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								prototorch/models/y_arch/library/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					from .glvq import GLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    "GLVQ",
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
							
								
								
									
										35
									
								
								prototorch/models/y_arch/library/glvq.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								prototorch/models/y_arch/library/glvq.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from prototorch.models.y_arch import (
 | 
				
			||||||
 | 
					    SimpleComparisonMixin,
 | 
				
			||||||
 | 
					    SingleLearningRateMixin,
 | 
				
			||||||
 | 
					    SupervisedArchitecture,
 | 
				
			||||||
 | 
					    WTACompetitionMixin,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from prototorch.models.y_arch.architectures.loss import GLVQLossMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GLVQ(
 | 
				
			||||||
 | 
					        SupervisedArchitecture,
 | 
				
			||||||
 | 
					        SimpleComparisonMixin,
 | 
				
			||||||
 | 
					        GLVQLossMixin,
 | 
				
			||||||
 | 
					        WTACompetitionMixin,
 | 
				
			||||||
 | 
					        SingleLearningRateMixin,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Generalized Learning Vector Quantization (GLVQ)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class HyperParameters(
 | 
				
			||||||
 | 
					            SimpleComparisonMixin.HyperParameters,
 | 
				
			||||||
 | 
					            SingleLearningRateMixin.HyperParameters,
 | 
				
			||||||
 | 
					            GLVQLossMixin.HyperParameters,
 | 
				
			||||||
 | 
					            WTACompetitionMixin.HyperParameters,
 | 
				
			||||||
 | 
					            SupervisedArchitecture.HyperParameters,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        No hyperparameters.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
							
								
								
									
										119
									
								
								prototorch/models/y_arch/library/gmlvq.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										119
									
								
								prototorch/models/y_arch/library/gmlvq.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,119 @@
 | 
				
			|||||||
 | 
					from __future__ import annotations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from dataclasses import dataclass, field
 | 
				
			||||||
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.core.distances import omega_distance
 | 
				
			||||||
 | 
					from prototorch.core.initializers import (
 | 
				
			||||||
 | 
					    AbstractLinearTransformInitializer,
 | 
				
			||||||
 | 
					    EyeLinearTransformInitializer,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from prototorch.models.y_arch import (
 | 
				
			||||||
 | 
					    GLVQLossMixin,
 | 
				
			||||||
 | 
					    SimpleComparisonMixin,
 | 
				
			||||||
 | 
					    SupervisedArchitecture,
 | 
				
			||||||
 | 
					    WTACompetitionMixin,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GMLVQ(
 | 
				
			||||||
 | 
					        SupervisedArchitecture,
 | 
				
			||||||
 | 
					        SimpleComparisonMixin,
 | 
				
			||||||
 | 
					        GLVQLossMixin,
 | 
				
			||||||
 | 
					        WTACompetitionMixin,
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Generalized Matrix Learning Vector Quantization (GMLVQ)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    _omega: torch.Tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # HyperParameters
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @dataclass
 | 
				
			||||||
 | 
					    class HyperParameters(
 | 
				
			||||||
 | 
					            SimpleComparisonMixin.HyperParameters,
 | 
				
			||||||
 | 
					            GLVQLossMixin.HyperParameters,
 | 
				
			||||||
 | 
					            WTACompetitionMixin.HyperParameters,
 | 
				
			||||||
 | 
					            SupervisedArchitecture.HyperParameters,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance.
 | 
				
			||||||
 | 
					        comparison_args: Keyword arguments for the comparison function. Override Default: {}.
 | 
				
			||||||
 | 
					        input_dim: Necessary Field: The dimensionality of the input.
 | 
				
			||||||
 | 
					        latent_dim: The dimensionality of the latent space. Default: 2.
 | 
				
			||||||
 | 
					        omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        backbone_lr: float = 0.1
 | 
				
			||||||
 | 
					        lr: float = 0.1
 | 
				
			||||||
 | 
					        comparison_fn: Callable = omega_distance
 | 
				
			||||||
 | 
					        comparison_args: dict = field(default_factory=lambda: dict())
 | 
				
			||||||
 | 
					        input_dim: int | None = None
 | 
				
			||||||
 | 
					        latent_dim: int = 2
 | 
				
			||||||
 | 
					        omega_initializer: type[
 | 
				
			||||||
 | 
					            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    def __init__(self, hparams) -> None:
 | 
				
			||||||
 | 
					        super().__init__(hparams)
 | 
				
			||||||
 | 
					        self.lr = hparams.lr
 | 
				
			||||||
 | 
					        self.backbone_lr = hparams.backbone_lr
 | 
				
			||||||
 | 
					        self.optimizer = hparams.optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_comparison(self, hparams: HyperParameters) -> None:
 | 
				
			||||||
 | 
					        if hparams.input_dim is None:
 | 
				
			||||||
 | 
					            raise ValueError("input_dim must be specified.")
 | 
				
			||||||
 | 
					        omega = hparams.omega_initializer().generate(
 | 
				
			||||||
 | 
					            hparams.input_dim,
 | 
				
			||||||
 | 
					            hparams.latent_dim,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.register_parameter("_omega", Parameter(omega))
 | 
				
			||||||
 | 
					        self.comparison_layer = LambdaLayer(
 | 
				
			||||||
 | 
					            fn=hparams.comparison_fn,
 | 
				
			||||||
 | 
					            **hparams.comparison_args,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def comparison(self, batch, components):
 | 
				
			||||||
 | 
					        comp_tensor, _ = components
 | 
				
			||||||
 | 
					        batch_tensor, _ = batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        comp_tensor = comp_tensor.unsqueeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        distances = self.comparison_layer(
 | 
				
			||||||
 | 
					            batch_tensor,
 | 
				
			||||||
 | 
					            comp_tensor,
 | 
				
			||||||
 | 
					            self._omega,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def configure_optimizers(self):
 | 
				
			||||||
 | 
					        proto_opt = self.optimizer(
 | 
				
			||||||
 | 
					            self.components_layer.parameters(),
 | 
				
			||||||
 | 
					            lr=self.lr,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        omega_opt = self.optimizer(
 | 
				
			||||||
 | 
					            [self._omega],
 | 
				
			||||||
 | 
					            lr=self.backbone_lr,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        return [proto_opt, omega_opt]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Properties
 | 
				
			||||||
 | 
					    # ----------------------------------------------------------------------------------------------------
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def omega_matrix(self):
 | 
				
			||||||
 | 
					        return self._omega.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def lambda_matrix(self):
 | 
				
			||||||
 | 
					        omega = self._omega.detach()
 | 
				
			||||||
 | 
					        lam = omega @ omega.T
 | 
				
			||||||
 | 
					        return lam.detach().cpu()
 | 
				
			||||||
		Reference in New Issue
	
	Block a user