Compare commits
	
		
			6 Commits
		
	
	
		
			main
			...
			feature/ux
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					5ce326ce62 | ||
| 
						 | 
					d1985571b3 | ||
| 
						 | 
					967953442b | ||
| 
						 | 
					d4448f2bc9 | ||
| 
						 | 
					a8829945f5 | ||
| 
						 | 
					a8336ee213 | 
@@ -18,12 +18,12 @@ repos:
 | 
				
			|||||||
  - id: autoflake
 | 
					  - id: autoflake
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- repo: http://github.com/PyCQA/isort
 | 
					- repo: http://github.com/PyCQA/isort
 | 
				
			||||||
  rev: 5.8.0
 | 
					  rev: 5.9.3
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: isort
 | 
					  - id: isort
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- repo: https://github.com/pre-commit/mirrors-mypy
 | 
					- repo: https://github.com/pre-commit/mirrors-mypy
 | 
				
			||||||
  rev: v0.902
 | 
					  rev: v0.910-1
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: mypy
 | 
					  - id: mypy
 | 
				
			||||||
    files: prototorch
 | 
					    files: prototorch
 | 
				
			||||||
@@ -42,9 +42,10 @@ repos:
 | 
				
			|||||||
  - id: python-check-blanket-noqa
 | 
					  - id: python-check-blanket-noqa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- repo: https://github.com/asottile/pyupgrade
 | 
					- repo: https://github.com/asottile/pyupgrade
 | 
				
			||||||
  rev: v2.19.4
 | 
					  rev: v2.29.0
 | 
				
			||||||
  hooks:
 | 
					  hooks:
 | 
				
			||||||
  - id: pyupgrade
 | 
					  - id: pyupgrade
 | 
				
			||||||
 | 
					    args: [--py36-plus]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- repo: https://github.com/si-cim/gitlint
 | 
					- repo: https://github.com/si-cim/gitlint
 | 
				
			||||||
  rev: v0.15.2-unofficial
 | 
					  rev: v0.15.2-unofficial
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -38,10 +38,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisCBC2D(data=train_ds,
 | 
					    vis = pt.models.Visualize2DVoronoiCallback(
 | 
				
			||||||
                             title="CBC Iris Example",
 | 
					        data=train_ds,
 | 
				
			||||||
                             resolution=100,
 | 
					        title="CBC Iris Example",
 | 
				
			||||||
                             axis_off=True)
 | 
					        resolution=100,
 | 
				
			||||||
 | 
					        axis_off=True,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@
 | 
				
			|||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import prototorch.models.clcc
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
@@ -29,7 +30,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize the model
 | 
					    # Initialize the model
 | 
				
			||||||
    model = pt.models.GLVQ(
 | 
					    model = prototorch.models.GLVQ(
 | 
				
			||||||
        hparams,
 | 
					        hparams,
 | 
				
			||||||
        optimizer=torch.optim.Adam,
 | 
					        optimizer=torch.optim.Adam,
 | 
				
			||||||
        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
					        prototypes_initializer=pt.initializers.SMCI(train_ds),
 | 
				
			||||||
@@ -41,7 +42,13 @@ if __name__ == "__main__":
 | 
				
			|||||||
    model.example_input_array = torch.zeros(4, 2)
 | 
					    model.example_input_array = torch.zeros(4, 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Callbacks
 | 
					    # Callbacks
 | 
				
			||||||
    vis = pt.models.VisGLVQ2D(data=train_ds)
 | 
					    vis = pt.models.Visualize2DVoronoiCallback(
 | 
				
			||||||
 | 
					        data=train_ds,
 | 
				
			||||||
 | 
					        resolution=200,
 | 
				
			||||||
 | 
					        title="Example: GLVQ on Iris",
 | 
				
			||||||
 | 
					        x_label="sepal length",
 | 
				
			||||||
 | 
					        y_label="petal length",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,13 +3,12 @@
 | 
				
			|||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					from prototorch.core.competitions import WTAC
 | 
				
			||||||
from ..core.competitions import WTAC
 | 
					from prototorch.core.components import Components, LabeledComponents
 | 
				
			||||||
from ..core.components import Components, LabeledComponents
 | 
					from prototorch.core.distances import euclidean_distance
 | 
				
			||||||
from ..core.distances import euclidean_distance
 | 
					from prototorch.core.initializers import LabelsInitializer
 | 
				
			||||||
from ..core.initializers import LabelsInitializer
 | 
					from prototorch.core.pooling import stratified_min_pooling
 | 
				
			||||||
from ..core.pooling import stratified_min_pooling
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
from ..nn.wrappers import LambdaLayer
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ProtoTorchBolt(pl.LightningModule):
 | 
					class ProtoTorchBolt(pl.LightningModule):
 | 
				
			||||||
@@ -169,32 +168,3 @@ class SupervisedPrototypeModel(PrototypeModel):
 | 
				
			|||||||
        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
 | 
					        accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.log("test_acc", accuracy)
 | 
					        self.log("test_acc", accuracy)
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ProtoTorchMixin(object):
 | 
					 | 
				
			||||||
    """All mixins are ProtoTorchMixins."""
 | 
					 | 
				
			||||||
    pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class NonGradientMixin(ProtoTorchMixin):
 | 
					 | 
				
			||||||
    """Mixin for custom non-gradient optimization."""
 | 
					 | 
				
			||||||
    def __init__(self, *args, **kwargs):
 | 
					 | 
				
			||||||
        super().__init__(*args, **kwargs)
 | 
					 | 
				
			||||||
        self.automatic_optimization = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
					 | 
				
			||||||
        raise NotImplementedError
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ImagePrototypesMixin(ProtoTorchMixin):
 | 
					 | 
				
			||||||
    """Mixin for models with image prototypes."""
 | 
					 | 
				
			||||||
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
 | 
					 | 
				
			||||||
        """Constrain the components to the range [0, 1] by clamping after updates."""
 | 
					 | 
				
			||||||
        self.proto_layer.components.data.clamp_(0.0, 1.0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
 | 
					 | 
				
			||||||
        from torchvision.utils import make_grid
 | 
					 | 
				
			||||||
        grid = make_grid(self.components, nrow=num_columns)
 | 
					 | 
				
			||||||
        if return_channels_last:
 | 
					 | 
				
			||||||
            grid = grid.permute((1, 2, 0))
 | 
					 | 
				
			||||||
        return grid.cpu()
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,9 +4,9 @@ import logging
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.core.components import Components
 | 
				
			||||||
 | 
					from prototorch.core.initializers import LiteralCompInitializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.components import Components
 | 
					 | 
				
			||||||
from ..core.initializers import LiteralCompInitializer
 | 
					 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,14 +1,14 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					from prototorch.core.competitions import CBCC
 | 
				
			||||||
 | 
					from prototorch.core.components import ReasoningComponents
 | 
				
			||||||
 | 
					from prototorch.core.initializers import RandomReasoningsInitializer
 | 
				
			||||||
 | 
					from prototorch.core.losses import MarginLoss
 | 
				
			||||||
 | 
					from prototorch.core.similarities import euclidean_similarity
 | 
				
			||||||
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.competitions import CBCC
 | 
					 | 
				
			||||||
from ..core.components import ReasoningComponents
 | 
					 | 
				
			||||||
from ..core.initializers import RandomReasoningsInitializer
 | 
					 | 
				
			||||||
from ..core.losses import MarginLoss
 | 
					 | 
				
			||||||
from ..core.similarities import euclidean_similarity
 | 
					 | 
				
			||||||
from ..nn.wrappers import LambdaLayer
 | 
					 | 
				
			||||||
from .abstract import ImagePrototypesMixin
 | 
					 | 
				
			||||||
from .glvq import SiameseGLVQ
 | 
					from .glvq import SiameseGLVQ
 | 
				
			||||||
 | 
					from .mixin import ImagePrototypesMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CBC(SiameseGLVQ):
 | 
					class CBC(SiameseGLVQ):
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										0
									
								
								prototorch/models/clcc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								prototorch/models/clcc/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										86
									
								
								prototorch/models/clcc/clcc_glvq.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								prototorch/models/clcc/clcc_glvq.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					from typing import Callable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.core.competitions import WTAC
 | 
				
			||||||
 | 
					from prototorch.core.components import LabeledComponents
 | 
				
			||||||
 | 
					from prototorch.core.distances import euclidean_distance
 | 
				
			||||||
 | 
					from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer
 | 
				
			||||||
 | 
					from prototorch.core.losses import GLVQLoss
 | 
				
			||||||
 | 
					from prototorch.models.clcc.clcc_scheme import CLCCScheme
 | 
				
			||||||
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@dataclass
 | 
				
			||||||
 | 
					class GLVQhparams:
 | 
				
			||||||
 | 
					    distribution: dict
 | 
				
			||||||
 | 
					    component_initializer: AbstractComponentsInitializer
 | 
				
			||||||
 | 
					    distance_fn: Callable = euclidean_distance
 | 
				
			||||||
 | 
					    lr: float = 0.01
 | 
				
			||||||
 | 
					    margin: float = 0.0
 | 
				
			||||||
 | 
					    # TODO: make nicer
 | 
				
			||||||
 | 
					    transfer_fn: str = "identity"
 | 
				
			||||||
 | 
					    transfer_beta: float = 10.0
 | 
				
			||||||
 | 
					    optimizer: torch.optim.Optimizer = torch.optim.Adam
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GLVQ(CLCCScheme):
 | 
				
			||||||
 | 
					    def __init__(self, hparams: GLVQhparams) -> None:
 | 
				
			||||||
 | 
					        super().__init__(hparams)
 | 
				
			||||||
 | 
					        self.lr = hparams.lr
 | 
				
			||||||
 | 
					        self.optimizer = hparams.optimizer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initializers
 | 
				
			||||||
 | 
					    def init_components(self, hparams):
 | 
				
			||||||
 | 
					        # initialize Component Layer
 | 
				
			||||||
 | 
					        self.components_layer = LabeledComponents(
 | 
				
			||||||
 | 
					            distribution=hparams.distribution,
 | 
				
			||||||
 | 
					            components_initializer=hparams.component_initializer,
 | 
				
			||||||
 | 
					            labels_initializer=LabelsInitializer(),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_comparison(self, hparams):
 | 
				
			||||||
 | 
					        # initialize Distance Layer
 | 
				
			||||||
 | 
					        self.comparison_layer = LambdaLayer(hparams.distance_fn)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_inference(self, hparams):
 | 
				
			||||||
 | 
					        self.competition_layer = WTAC()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_loss(self, hparams):
 | 
				
			||||||
 | 
					        self.loss_layer = GLVQLoss(
 | 
				
			||||||
 | 
					            margin=hparams.margin,
 | 
				
			||||||
 | 
					            transfer_fn=hparams.transfer_fn,
 | 
				
			||||||
 | 
					            beta=hparams.transfer_beta,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Steps
 | 
				
			||||||
 | 
					    def comparison(self, batch, components):
 | 
				
			||||||
 | 
					        comp_tensor, _ = components
 | 
				
			||||||
 | 
					        batch_tensor, _ = batch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        comp_tensor = comp_tensor.unsqueeze(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        distances = self.comparison_layer(batch_tensor, comp_tensor)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return distances
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def inference(self, comparisonmeasures, components):
 | 
				
			||||||
 | 
					        comp_labels = components[1]
 | 
				
			||||||
 | 
					        return self.competition_layer(comparisonmeasures, comp_labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def loss(self, comparisonmeasures, batch, components):
 | 
				
			||||||
 | 
					        target = batch[1]
 | 
				
			||||||
 | 
					        comp_labels = components[1]
 | 
				
			||||||
 | 
					        return self.loss_layer(comparisonmeasures, target, comp_labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def configure_optimizers(self):
 | 
				
			||||||
 | 
					        return self.optimizer(self.parameters(), lr=self.lr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Properties
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def prototypes(self):
 | 
				
			||||||
 | 
					        return self.components_layer.components.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def prototype_labels(self):
 | 
				
			||||||
 | 
					        return self.components_layer.labels.detach().cpu()
 | 
				
			||||||
							
								
								
									
										192
									
								
								prototorch/models/clcc/clcc_scheme.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										192
									
								
								prototorch/models/clcc/clcc_scheme.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,192 @@
 | 
				
			|||||||
 | 
					"""
 | 
				
			||||||
 | 
					CLCC Scheme
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CLCC is a LVQ scheme containing 4 steps
 | 
				
			||||||
 | 
					- Components
 | 
				
			||||||
 | 
					- Latent Space
 | 
				
			||||||
 | 
					- Comparison
 | 
				
			||||||
 | 
					- Competition
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					from typing import Dict, Set, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CLCCScheme(pl.LightningModule):
 | 
				
			||||||
 | 
					    registered_metrics: Dict[Type[torchmetrics.Metric],
 | 
				
			||||||
 | 
					                             torchmetrics.Metric] = {}
 | 
				
			||||||
 | 
					    registered_metric_names: Dict[Type[torchmetrics.Metric], Set[str]] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, hparams) -> None:
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Common Steps
 | 
				
			||||||
 | 
					        self.init_components(hparams)
 | 
				
			||||||
 | 
					        self.init_latent(hparams)
 | 
				
			||||||
 | 
					        self.init_comparison(hparams)
 | 
				
			||||||
 | 
					        self.init_competition(hparams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Train Steps
 | 
				
			||||||
 | 
					        self.init_loss(hparams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Inference Steps
 | 
				
			||||||
 | 
					        self.init_inference(hparams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Initialize Model Metrics
 | 
				
			||||||
 | 
					        self.init_model_metrics()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # internal API, called by models and callbacks
 | 
				
			||||||
 | 
					    def register_torchmetric(self, name: str, metric: torchmetrics.Metric):
 | 
				
			||||||
 | 
					        if metric not in self.registered_metrics:
 | 
				
			||||||
 | 
					            self.registered_metrics[metric] = metric()
 | 
				
			||||||
 | 
					            self.registered_metric_names[metric] = {name}
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.registered_metric_names[metric].add(name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # external API
 | 
				
			||||||
 | 
					    def get_competion(self, batch, components):
 | 
				
			||||||
 | 
					        latent_batch, latent_components = self.latent(batch, components)
 | 
				
			||||||
 | 
					        # TODO: => Latent Hook
 | 
				
			||||||
 | 
					        comparison_tensor = self.comparison(latent_batch, latent_components)
 | 
				
			||||||
 | 
					        # TODO: => Comparison Hook
 | 
				
			||||||
 | 
					        return comparison_tensor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, batch):
 | 
				
			||||||
 | 
					        if isinstance(batch, torch.Tensor):
 | 
				
			||||||
 | 
					            batch = (batch, None)
 | 
				
			||||||
 | 
					        # TODO: manage different datatypes?
 | 
				
			||||||
 | 
					        components = self.components_layer()
 | 
				
			||||||
 | 
					        # TODO: => Component Hook
 | 
				
			||||||
 | 
					        comparison_tensor = self.get_competion(batch, components)
 | 
				
			||||||
 | 
					        # TODO: => Competition Hook
 | 
				
			||||||
 | 
					        return self.inference(comparison_tensor, components)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def predict(self, batch):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Alias for forward
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return self.forward(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def loss_forward(self, batch):
 | 
				
			||||||
 | 
					        # TODO: manage different datatypes?
 | 
				
			||||||
 | 
					        components = self.components_layer()
 | 
				
			||||||
 | 
					        # TODO: => Component Hook
 | 
				
			||||||
 | 
					        comparison_tensor = self.get_competion(batch, components)
 | 
				
			||||||
 | 
					        # TODO: => Competition Hook
 | 
				
			||||||
 | 
					        return self.loss(comparison_tensor, batch, components)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Empty Initialization
 | 
				
			||||||
 | 
					    # TODO: Type hints
 | 
				
			||||||
 | 
					    # TODO: Docs
 | 
				
			||||||
 | 
					    def init_components(self, hparams):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_latent(self, hparams):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_comparison(self, hparams):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_competition(self, hparams):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_loss(self, hparams):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_inference(self, hparams):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_model_metrics(self):
 | 
				
			||||||
 | 
					        self.register_torchmetric('train_accuracy', torchmetrics.Accuracy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Empty Steps
 | 
				
			||||||
 | 
					    # TODO: Type hints
 | 
				
			||||||
 | 
					    def components(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        This step has no input.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        It returns the components.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError(
 | 
				
			||||||
 | 
					            "The components step has no reasonable default.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def latent(self, batch, components):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        The latent step receives the data batch and the components.
 | 
				
			||||||
 | 
					        It can transform both by an arbitrary function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        It returns the transformed batch and components, each of the same length as the original input.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        return batch, components
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def comparison(self, batch, components):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Takes a batch of size N and the componentsset of size M.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError(
 | 
				
			||||||
 | 
					            "The comparison step has no reasonable default.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def competition(self, comparisonmeasures, components):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Takes the tensor of comparison measures.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Assigns a competition vector to each class.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError(
 | 
				
			||||||
 | 
					            "The competition step has no reasonable default.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def loss(self, comparisonmeasures, batch, components):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Takes the tensor of competition measures.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Calculates a single loss value
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError("The loss step has no reasonable default.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def inference(self, comparisonmeasures, components):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Takes the tensor of competition measures.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns the inferred vector.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError(
 | 
				
			||||||
 | 
					            "The inference step has no reasonable default.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update_metrics_step(self, batch):
 | 
				
			||||||
 | 
					        x, y = batch
 | 
				
			||||||
 | 
					        preds = self(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for metric in self.registered_metrics:
 | 
				
			||||||
 | 
					            instance = self.registered_metrics[metric].to(self.device)
 | 
				
			||||||
 | 
					            value = instance(y, preds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for name in self.registered_metric_names[metric]:
 | 
				
			||||||
 | 
					                self.log(name, value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update_metrics_epoch(self):
 | 
				
			||||||
 | 
					        for metric in self.registered_metrics:
 | 
				
			||||||
 | 
					            instance = self.registered_metrics[metric].to(self.device)
 | 
				
			||||||
 | 
					            value = instance.compute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for name in self.registered_metric_names[metric]:
 | 
				
			||||||
 | 
					                self.log(name, value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Lightning Hooks
 | 
				
			||||||
 | 
					    def training_step(self, batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
 | 
					        self.update_metrics_step(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return self.loss_forward(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def train_epoch_end(self, outs) -> None:
 | 
				
			||||||
 | 
					        self.update_metrics_epoch()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def validation_step(self, batch, batch_idx):
 | 
				
			||||||
 | 
					        return self.loss_forward(batch)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_step(self, batch, batch_idx):
 | 
				
			||||||
 | 
					        return self.loss_forward(batch)
 | 
				
			||||||
							
								
								
									
										76
									
								
								prototorch/models/clcc/test_clcc.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								prototorch/models/clcc/test_clcc.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,76 @@
 | 
				
			|||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					import torchmetrics
 | 
				
			||||||
 | 
					from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
 | 
				
			||||||
 | 
					from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
 | 
				
			||||||
 | 
					from prototorch.models.clcc.clcc_scheme import CLCCScheme
 | 
				
			||||||
 | 
					from prototorch.models.vis import Visualize2DVoronoiCallback
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# NEW STUFF
 | 
				
			||||||
 | 
					# ##############################################################################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO: Metrics
 | 
				
			||||||
 | 
					class MetricsTestCallback(pl.Callback):
 | 
				
			||||||
 | 
					    metric_name = "test_cb_acc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setup(self,
 | 
				
			||||||
 | 
					              trainer: pl.Trainer,
 | 
				
			||||||
 | 
					              pl_module: CLCCScheme,
 | 
				
			||||||
 | 
					              stage: Optional[str] = None) -> None:
 | 
				
			||||||
 | 
					        pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_epoch_end(self, trainer: pl.Trainer,
 | 
				
			||||||
 | 
					                     pl_module: pl.LightningModule) -> None:
 | 
				
			||||||
 | 
					        metric = trainer.logged_metrics[self.metric_name]
 | 
				
			||||||
 | 
					        if metric > 0.95:
 | 
				
			||||||
 | 
					            trainer.should_stop = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO: Pruning
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ##############################################################################
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Dataset
 | 
				
			||||||
 | 
					    train_ds = pt.datasets.Iris(dims=[0, 2])
 | 
				
			||||||
 | 
					    # Dataloaders
 | 
				
			||||||
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds,
 | 
				
			||||||
 | 
					                                               batch_size=64,
 | 
				
			||||||
 | 
					                                               num_workers=8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    components_initializer = SMCI(train_ds)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    hparams = GLVQhparams(
 | 
				
			||||||
 | 
					        distribution=dict(
 | 
				
			||||||
 | 
					            num_classes=3,
 | 
				
			||||||
 | 
					            per_class=2,
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        component_initializer=components_initializer,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    model = GLVQ(hparams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = Visualize2DVoronoiCallback(
 | 
				
			||||||
 | 
					        data=train_ds,
 | 
				
			||||||
 | 
					        resolution=500,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    metrics = MetricsTestCallback()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Train
 | 
				
			||||||
 | 
					    trainer = pl.Trainer(
 | 
				
			||||||
 | 
					        callbacks=[
 | 
				
			||||||
 | 
					            #vis,
 | 
				
			||||||
 | 
					            metrics,
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					        gpus=1,
 | 
				
			||||||
 | 
					        max_epochs=100,
 | 
				
			||||||
 | 
					        weights_summary=None,
 | 
				
			||||||
 | 
					        log_every_n_steps=1,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
@@ -5,8 +5,7 @@ Modules not yet available in prototorch go here temporarily.
 | 
				
			|||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.core.similarities import gaussian
 | 
				
			||||||
from ..core.similarities import gaussian
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rank_scaled_gaussian(distances, lambd):
 | 
					def rank_scaled_gaussian(distances, lambd):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,15 +1,16 @@
 | 
				
			|||||||
"""Models based on the GLVQ framework."""
 | 
					"""Models based on the GLVQ framework."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.core.competitions import wtac
 | 
				
			||||||
 | 
					from prototorch.core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 | 
				
			||||||
 | 
					from prototorch.core.initializers import EyeTransformInitializer
 | 
				
			||||||
 | 
					from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
 | 
				
			||||||
 | 
					from prototorch.core.transforms import LinearTransform
 | 
				
			||||||
 | 
					from prototorch.nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
from torch.nn.parameter import Parameter
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.competitions import wtac
 | 
					from .abstract import SupervisedPrototypeModel
 | 
				
			||||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 | 
					from .mixin import ImagePrototypesMixin
 | 
				
			||||||
from ..core.initializers import EyeTransformInitializer
 | 
					 | 
				
			||||||
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
 | 
					 | 
				
			||||||
from ..core.transforms import LinearTransform
 | 
					 | 
				
			||||||
from ..nn.wrappers import LambdaLayer, LossLayer
 | 
					 | 
				
			||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class GLVQ(SupervisedPrototypeModel):
 | 
					class GLVQ(SupervisedPrototypeModel):
 | 
				
			||||||
@@ -130,7 +131,7 @@ class SiameseGLVQ(GLVQ):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def compute_distances(self, x):
 | 
					    def compute_distances(self, x):
 | 
				
			||||||
        protos, _ = self.proto_layer()
 | 
					        protos, _ = self.proto_layer()
 | 
				
			||||||
        x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
 | 
					        x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
 | 
				
			||||||
        latent_x = self.backbone(x)
 | 
					        latent_x = self.backbone(x)
 | 
				
			||||||
        self.backbone.requires_grad_(self.both_path_gradients)
 | 
					        self.backbone.requires_grad_(self.both_path_gradients)
 | 
				
			||||||
        latent_protos = self.backbone(protos)
 | 
					        latent_protos = self.backbone(protos)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,10 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.competitions import KNNC
 | 
					from prototorch.core.competitions import KNNC
 | 
				
			||||||
from ..core.components import LabeledComponents
 | 
					from prototorch.core.components import LabeledComponents
 | 
				
			||||||
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
 | 
					from prototorch.core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
 | 
				
			||||||
from ..utils.utils import parse_data_arg
 | 
					from prototorch.utils.utils import parse_data_arg
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .abstract import SupervisedPrototypeModel
 | 
					from .abstract import SupervisedPrototypeModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,10 +1,11 @@
 | 
				
			|||||||
"""LVQ models that are optimized using non-gradient methods."""
 | 
					"""LVQ models that are optimized using non-gradient methods."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.losses import _get_dp_dm
 | 
					from prototorch.core.losses import _get_dp_dm
 | 
				
			||||||
from ..nn.activations import get_activation
 | 
					from prototorch.nn.activations import get_activation
 | 
				
			||||||
from ..nn.wrappers import LambdaLayer
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
from .abstract import NonGradientMixin
 | 
					
 | 
				
			||||||
from .glvq import GLVQ
 | 
					from .glvq import GLVQ
 | 
				
			||||||
 | 
					from .mixin import NonGradientMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LVQ1(NonGradientMixin, GLVQ):
 | 
					class LVQ1(NonGradientMixin, GLVQ):
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										27
									
								
								prototorch/models/mixin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								prototorch/models/mixin.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
				
			|||||||
 | 
					class ProtoTorchMixin:
 | 
				
			||||||
 | 
					    """All mixins are ProtoTorchMixins."""
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NonGradientMixin(ProtoTorchMixin):
 | 
				
			||||||
 | 
					    """Mixin for custom non-gradient optimization."""
 | 
				
			||||||
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					        self.automatic_optimization = False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ImagePrototypesMixin(ProtoTorchMixin):
 | 
				
			||||||
 | 
					    """Mixin for models with image prototypes."""
 | 
				
			||||||
 | 
					    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
 | 
				
			||||||
 | 
					        """Constrain the components to the range [0, 1] by clamping after updates."""
 | 
				
			||||||
 | 
					        self.proto_layer.components.data.clamp_(0.0, 1.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_prototype_grid(self, num_columns=2, return_channels_last=True):
 | 
				
			||||||
 | 
					        from torchvision.utils import make_grid
 | 
				
			||||||
 | 
					        grid = make_grid(self.components, nrow=num_columns)
 | 
				
			||||||
 | 
					        if return_channels_last:
 | 
				
			||||||
 | 
					            grid = grid.permute((1, 2, 0))
 | 
				
			||||||
 | 
					        return grid.cpu()
 | 
				
			||||||
@@ -1,10 +1,10 @@
 | 
				
			|||||||
"""Probabilistic GLVQ methods"""
 | 
					"""Probabilistic GLVQ methods"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.core.losses import nllr_loss, rslvq_loss
 | 
				
			||||||
 | 
					from prototorch.core.pooling import stratified_min_pooling, stratified_sum_pooling
 | 
				
			||||||
 | 
					from prototorch.nn.wrappers import LambdaLayer, LossLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.losses import nllr_loss, rslvq_loss
 | 
					 | 
				
			||||||
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
 | 
					 | 
				
			||||||
from ..nn.wrappers import LambdaLayer, LossLayer
 | 
					 | 
				
			||||||
from .extras import GaussianPrior, RankScaledGaussianPrior
 | 
					from .extras import GaussianPrior, RankScaledGaussianPrior
 | 
				
			||||||
from .glvq import GLVQ, SiameseGMLVQ
 | 
					from .glvq import GLVQ, SiameseGMLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,14 +2,15 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					from prototorch.core.competitions import wtac
 | 
				
			||||||
 | 
					from prototorch.core.distances import squared_euclidean_distance
 | 
				
			||||||
 | 
					from prototorch.core.losses import NeuralGasEnergy
 | 
				
			||||||
 | 
					from prototorch.nn.wrappers import LambdaLayer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..core.competitions import wtac
 | 
					from .abstract import UnsupervisedPrototypeModel
 | 
				
			||||||
from ..core.distances import squared_euclidean_distance
 | 
					 | 
				
			||||||
from ..core.losses import NeuralGasEnergy
 | 
					 | 
				
			||||||
from ..nn.wrappers import LambdaLayer
 | 
					 | 
				
			||||||
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
 | 
					 | 
				
			||||||
from .callbacks import GNGCallback
 | 
					from .callbacks import GNGCallback
 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 | 
					from .mixin import NonGradientMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
 | 
					class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,15 +5,18 @@ import pytorch_lightning as pl
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchvision
 | 
					import torchvision
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					from prototorch.utils.utils import generate_mesh, mesh2d
 | 
				
			||||||
from torch.utils.data import DataLoader, Dataset
 | 
					from torch.utils.data import DataLoader, Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..utils.utils import mesh2d
 | 
					COLOR_UNLABELED = 'w'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Vis2DAbstract(pl.Callback):
 | 
					class Vis2DAbstract(pl.Callback):
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
                 data,
 | 
					                 data,
 | 
				
			||||||
                 title="Prototype Visualization",
 | 
					                 title=None,
 | 
				
			||||||
 | 
					                 x_label=None,
 | 
				
			||||||
 | 
					                 y_label=None,
 | 
				
			||||||
                 cmap="viridis",
 | 
					                 cmap="viridis",
 | 
				
			||||||
                 border=0.1,
 | 
					                 border=0.1,
 | 
				
			||||||
                 resolution=100,
 | 
					                 resolution=100,
 | 
				
			||||||
@@ -45,6 +48,8 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        self.y_train = y
 | 
					        self.y_train = y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.title = title
 | 
					        self.title = title
 | 
				
			||||||
 | 
					        self.x_label = x_label
 | 
				
			||||||
 | 
					        self.y_label = y_label
 | 
				
			||||||
        self.fig = plt.figure(self.title)
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
        self.cmap = cmap
 | 
					        self.cmap = cmap
 | 
				
			||||||
        self.border = border
 | 
					        self.border = border
 | 
				
			||||||
@@ -57,20 +62,19 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        self.pause_time = pause_time
 | 
					        self.pause_time = pause_time
 | 
				
			||||||
        self.block = block
 | 
					        self.block = block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def precheck(self, trainer):
 | 
					    def show_on_current_epoch(self, trainer):
 | 
				
			||||||
        if self.show_last_only:
 | 
					        if self.show_last_only and trainer.current_epoch != trainer.max_epochs - 1:
 | 
				
			||||||
            if trainer.current_epoch != trainer.max_epochs - 1:
 | 
					            return False
 | 
				
			||||||
                return False
 | 
					 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def setup_ax(self, xlabel=None, ylabel=None):
 | 
					    def setup_ax(self):
 | 
				
			||||||
        ax = self.fig.gca()
 | 
					        ax = self.fig.gca()
 | 
				
			||||||
        ax.cla()
 | 
					        ax.cla()
 | 
				
			||||||
        ax.set_title(self.title)
 | 
					        ax.set_title(self.title)
 | 
				
			||||||
        if xlabel:
 | 
					        if self.x_label:
 | 
				
			||||||
            ax.set_xlabel("Data dimension 1")
 | 
					            ax.set_xlabel(self.x_label)
 | 
				
			||||||
        if ylabel:
 | 
					        if self.x_label:
 | 
				
			||||||
            ax.set_ylabel("Data dimension 2")
 | 
					            ax.set_ylabel(self.y_label)
 | 
				
			||||||
        if self.axis_off:
 | 
					        if self.axis_off:
 | 
				
			||||||
            ax.axis("off")
 | 
					            ax.axis("off")
 | 
				
			||||||
        return ax
 | 
					        return ax
 | 
				
			||||||
@@ -117,25 +121,64 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        plt.close()
 | 
					        plt.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisGLVQ2D(Vis2DAbstract):
 | 
					class Visualize2DVoronoiCallback(Vis2DAbstract):
 | 
				
			||||||
 | 
					    def __init__(self, data, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(data, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.data_min = torch.min(self.x_train, axis=0).values
 | 
				
			||||||
 | 
					        self.data_max = torch.max(self.x_train, axis=0).values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def current_span(self, proto_values):
 | 
				
			||||||
 | 
					        proto_min = torch.min(proto_values, axis=0).values
 | 
				
			||||||
 | 
					        proto_max = torch.max(proto_values, axis=0).values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        overall_min = torch.minimum(proto_min, self.data_min)
 | 
				
			||||||
 | 
					        overall_max = torch.maximum(proto_max, self.data_max)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return overall_min, overall_max
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_voronoi_diagram(self, min, max, model):
 | 
				
			||||||
 | 
					        mesh_input, (xx, yy) = generate_mesh(
 | 
				
			||||||
 | 
					            min,
 | 
				
			||||||
 | 
					            max,
 | 
				
			||||||
 | 
					            border=self.border,
 | 
				
			||||||
 | 
					            resolution=self.resolution,
 | 
				
			||||||
 | 
					            device=model.device,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        y_pred = model.predict(mesh_input)
 | 
				
			||||||
 | 
					        return xx, yy, y_pred.reshape(xx.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					        if not self.show_on_current_epoch(trainer):
 | 
				
			||||||
            return True
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					        # Extract Prototypes
 | 
				
			||||||
        plabels = pl_module.prototype_labels
 | 
					        proto_values = pl_module.prototypes
 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					        if hasattr(pl_module, "prototype_labels"):
 | 
				
			||||||
        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
					            proto_labels = pl_module.prototype_labels
 | 
				
			||||||
                           ylabel="Data dimension 2")
 | 
					        else:
 | 
				
			||||||
        self.plot_data(ax, x_train, y_train)
 | 
					            proto_labels = COLOR_UNLABELED
 | 
				
			||||||
        self.plot_protos(ax, protos, plabels)
 | 
					
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        # Calculate Voronoi Diagram
 | 
				
			||||||
        mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
 | 
					        overall_min, overall_max = self.current_span(proto_values)
 | 
				
			||||||
        _components = pl_module.proto_layer._components
 | 
					        xx, yy, y_pred = self.get_voronoi_diagram(
 | 
				
			||||||
        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
 | 
					            overall_min,
 | 
				
			||||||
        y_pred = pl_module.predict(mesh_input)
 | 
					            overall_max,
 | 
				
			||||||
        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					            pl_module,
 | 
				
			||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
 | 
					        ax.contourf(
 | 
				
			||||||
 | 
					            xx.cpu(),
 | 
				
			||||||
 | 
					            yy.cpu(),
 | 
				
			||||||
 | 
					            y_pred.cpu(),
 | 
				
			||||||
 | 
					            cmap=self.cmap,
 | 
				
			||||||
 | 
					            alpha=0.35,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.plot_data(ax, self.x_train, self.y_train)
 | 
				
			||||||
 | 
					        self.plot_protos(ax, proto_values, proto_labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.log_and_display(trainer, pl_module)
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -146,7 +189,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        self.map_protos = map_protos
 | 
					        self.map_protos = map_protos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					        if not self.show_on_current_epoch(trainer):
 | 
				
			||||||
            return True
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
@@ -184,7 +227,7 @@ class VisGMLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        self.ev_proj = ev_proj
 | 
					        self.ev_proj = ev_proj
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					        if not self.show_on_current_epoch(trainer):
 | 
				
			||||||
            return True
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
@@ -211,40 +254,16 @@ class VisGMLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        self.log_and_display(trainer, pl_module)
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisCBC2D(Vis2DAbstract):
 | 
					 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					 | 
				
			||||||
            return True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					 | 
				
			||||||
        protos = pl_module.components
 | 
					 | 
				
			||||||
        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
					 | 
				
			||||||
                           ylabel="Data dimension 2")
 | 
					 | 
				
			||||||
        self.plot_data(ax, x_train, y_train)
 | 
					 | 
				
			||||||
        self.plot_protos(ax, protos, "w")
 | 
					 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					 | 
				
			||||||
        mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
 | 
					 | 
				
			||||||
        _components = pl_module.components_layer._components
 | 
					 | 
				
			||||||
        y_pred = pl_module.predict(
 | 
					 | 
				
			||||||
            torch.Tensor(mesh_input).type_as(_components))
 | 
					 | 
				
			||||||
        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.log_and_display(trainer, pl_module)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class VisNG2D(Vis2DAbstract):
 | 
					class VisNG2D(Vis2DAbstract):
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					        if not self.show_on_current_epoch(trainer):
 | 
				
			||||||
            return True
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
        cmat = pl_module.topology_layer.cmat.cpu().numpy()
 | 
					        cmat = pl_module.topology_layer.cmat.cpu().numpy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
                           ylabel="Data dimension 2")
 | 
					 | 
				
			||||||
        self.plot_data(ax, x_train, y_train)
 | 
					        self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
        self.plot_protos(ax, protos, "w")
 | 
					        self.plot_protos(ax, protos, "w")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -315,7 +334,7 @@ class VisImgComp(Vis2DAbstract):
 | 
				
			|||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
        if not self.precheck(trainer):
 | 
					        if not self.show_on_current_epoch(trainer):
 | 
				
			||||||
            return True
 | 
					            return True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.show:
 | 
					        if self.show:
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							@@ -18,7 +18,7 @@ PLUGIN_NAME = "models"
 | 
				
			|||||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
 | 
					PROJECT_URL = "https://github.com/si-cim/prototorch_models"
 | 
				
			||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
 | 
					DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
with open("README.md", "r") as fh:
 | 
					with open("README.md") as fh:
 | 
				
			||||||
    long_description = fh.read()
 | 
					    long_description = fh.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
INSTALL_REQUIRES = [
 | 
					INSTALL_REQUIRES = [
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user