Compare commits

14 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | 696719600b |  |
|  | 48e7c029fa |  |
|  | 5de3a480c7 |  |
|  | 626f51ce80 |  |
|  | 6d7d93c8e8 |  |
|  | 93b1d0bd46 |  |
|  | b7992c01db |  |
|  | 23d1a71b31 |  |
|  | e922aae432 |  |
|  | 3e50d0d817 |  |
|  | dc4f31d700 |  |
|  | 02954044d7 |  |
|  | 8f08ba66ea |  |
|  | e0b92e9ac2 |  |
.bumpversion.cfg

| @@ -1,9 +1,11 @@ | ||||
| [bumpversion] | ||||
| current_version = 0.5.2 | ||||
| current_version = 1.0.0a2 | ||||
| commit = True | ||||
| tag = True | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) | ||||
| serialize = {major}.{minor}.{patch} | ||||
| parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))? | ||||
| serialize =  | ||||
| 	{major}.{minor}.{patch}-{release} | ||||
| 	{major}.{minor}.{patch} | ||||
| message = build: bump version {current_version} → {new_version} | ||||
|  | ||||
| [bumpversion:file:setup.py] | ||||
|   | ||||
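As a sanity check on the widened `parse` pattern: it now captures an optional prerelease suffix, which the first `serialize` template renders when present. A standalone illustration using only Python's `re` module (not code from this repo):

```python
import re

# The parse pattern from the updated [bumpversion] section
PARSE = re.compile(r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)"
                   r"((?P<release>[a-zA-Z0-9_.-]+))?")

match = PARSE.match("1.0.0a2")
assert match is not None
print(match.groupdict())
# -> {'major': '1', 'minor': '0', 'patch': '0', 'release': 'a2'}
```

A plain release like `0.5.2` leaves `release` as `None`, so the second `serialize` template applies.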
.pre-commit-config.yaml

| @@ -6,6 +6,7 @@ repos: | ||||
|   rev: v4.2.0 | ||||
|   hooks: | ||||
|   - id: trailing-whitespace | ||||
|     exclude: (^\.bumpversion\.cfg$|cli_messages\.py) | ||||
|   - id: end-of-file-fixer | ||||
|   - id: check-yaml | ||||
|   - id: check-added-large-files | ||||
|   | ||||
docs/conf.py

| @@ -23,7 +23,7 @@ author = "Jensun Ravichandran" | ||||
|  | ||||
| # The full version, including alpha/beta/rc tags | ||||
| # | ||||
| release = "0.5.2" | ||||
| release = "1.0.0-a2" | ||||
|  | ||||
| # -- General configuration --------------------------------------------------- | ||||
|  | ||||
|   | ||||
							
								
								
									
examples/y_architecture_example.py (new file, 88 lines)
							| @@ -0,0 +1,88 @@ | ||||
| import prototorch as pt | ||||
| import pytorch_lightning as pl | ||||
| import torchmetrics | ||||
| from prototorch.core import SMCI | ||||
| from prototorch.y.callbacks import ( | ||||
|     LogTorchmetricCallback, | ||||
|     PlotLambdaMatrixToTensorboard, | ||||
|     VisGMLVQ2D, | ||||
| ) | ||||
| from prototorch.y.library.gmlvq import GMLVQ | ||||
| from pytorch_lightning.callbacks import EarlyStopping | ||||
| from torch.utils.data import DataLoader | ||||
|  | ||||
| # ############################################################################## | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     # ------------------------------------------------------------ | ||||
|     # DATA | ||||
|     # ------------------------------------------------------------ | ||||
|  | ||||
|     # Dataset | ||||
|     train_ds = pt.datasets.Iris() | ||||
|  | ||||
|     # Dataloader | ||||
|     train_loader = DataLoader( | ||||
|         train_ds, | ||||
|         batch_size=32, | ||||
|         num_workers=0, | ||||
|         shuffle=True, | ||||
|     ) | ||||
|  | ||||
|     # ------------------------------------------------------------ | ||||
|     # HYPERPARAMETERS | ||||
|     # ------------------------------------------------------------ | ||||
|  | ||||
|     # Select Initializer | ||||
|     components_initializer = SMCI(train_ds) | ||||
|  | ||||
|     # Define Hyperparameters | ||||
|     hyperparameters = GMLVQ.HyperParameters( | ||||
|         lr=dict(components_layer=0.1, _omega=0), | ||||
|         input_dim=4, | ||||
|         distribution=dict( | ||||
|             num_classes=3, | ||||
|             per_class=1, | ||||
|         ), | ||||
|         component_initializer=components_initializer, | ||||
|     ) | ||||
|  | ||||
|     # Create Model | ||||
|     model = GMLVQ(hyperparameters) | ||||
|  | ||||
|     print(model) | ||||
|  | ||||
|     # ------------------------------------------------------------ | ||||
|     # TRAINING | ||||
|     # ------------------------------------------------------------ | ||||
|  | ||||
|     # Controlling Callbacks | ||||
|     stopping_criterion = LogTorchmetricCallback( | ||||
|         'recall', | ||||
|         torchmetrics.Recall, | ||||
|         num_classes=3, | ||||
|     ) | ||||
|  | ||||
|     es = EarlyStopping( | ||||
|         monitor=stopping_criterion.name, | ||||
|         mode="max", | ||||
|         patience=10, | ||||
|     ) | ||||
|  | ||||
|     # Visualization Callback | ||||
|     vis = VisGMLVQ2D(data=train_ds) | ||||
|  | ||||
|     # Define trainer | ||||
|     trainer = pl.Trainer( | ||||
|         callbacks=[ | ||||
|             vis, | ||||
|             stopping_criterion, | ||||
|             es, | ||||
|             PlotLambdaMatrixToTensorboard(), | ||||
|         ], | ||||
|         max_epochs=1000, | ||||
|     ) | ||||
|  | ||||
|     # Train | ||||
|     trainer.fit(model, train_loader) | ||||
prototorch/models/__init__.py

| @@ -36,4 +36,4 @@ from .unsupervised import ( | ||||
| ) | ||||
| from .vis import * | ||||
|  | ||||
| __version__ = "0.5.2" | ||||
| __version__ = "1.0.0-a2" | ||||
|   | ||||
prototorch/models/abstract.py

| @@ -22,7 +22,16 @@ from prototorch.nn.wrappers import LambdaLayer | ||||
|  | ||||
|  | ||||
| class ProtoTorchBolt(pl.LightningModule): | ||||
|     """All ProtoTorch models are ProtoTorch Bolts.""" | ||||
|     """All ProtoTorch models are ProtoTorch Bolts. | ||||
|  | ||||
|     hparams: | ||||
|         - lr: learning rate | ||||
|  | ||||
|     kwargs: | ||||
|         - optimizer: optimizer class | ||||
|         - lr_scheduler: learning rate scheduler class | ||||
|         - lr_scheduler_kwargs: learning rate scheduler kwargs | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, hparams, **kwargs): | ||||
|         super().__init__() | ||||
| @@ -65,6 +74,11 @@ class ProtoTorchBolt(pl.LightningModule): | ||||
|  | ||||
|  | ||||
| class PrototypeModel(ProtoTorchBolt): | ||||
|     """Abstract Prototype Model | ||||
|  | ||||
|     kwargs: | ||||
|         - distance_fn: distance function | ||||
|     """ | ||||
|     proto_layer: AbstractComponents | ||||
|  | ||||
|     def __init__(self, hparams, **kwargs): | ||||
| @@ -203,35 +217,3 @@ class SupervisedPrototypeModel(PrototypeModel): | ||||
|         accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) | ||||
|  | ||||
|         self.log("test_acc", accuracy) | ||||
|  | ||||
|  | ||||
| class ProtoTorchMixin(object): | ||||
|     """All mixins are ProtoTorchMixins.""" | ||||
|  | ||||
|  | ||||
| class NonGradientMixin(ProtoTorchMixin): | ||||
|     """Mixin for custom non-gradient optimization.""" | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.automatic_optimization = False | ||||
|  | ||||
|     def training_step(self, train_batch, batch_idx, optimizer_idx=None): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class ImagePrototypesMixin(ProtoTorchMixin): | ||||
|     """Mixin for models with image prototypes.""" | ||||
|     proto_layer: Components | ||||
|     components: torch.Tensor | ||||
|  | ||||
|     def on_train_batch_end(self, outputs, batch, batch_idx): | ||||
|         """Constrain the components to the range [0, 1] by clamping after updates.""" | ||||
|         self.proto_layer.components.data.clamp_(0.0, 1.0) | ||||
|  | ||||
|     def get_prototype_grid(self, num_columns=2, return_channels_last=True): | ||||
|         from torchvision.utils import make_grid | ||||
|         grid = make_grid(self.components, nrow=num_columns) | ||||
|         if return_channels_last: | ||||
|             grid = grid.permute((1, 2, 0)) | ||||
|         return grid.cpu() | ||||
|   | ||||
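The expanded docstring makes the bolt-level kwargs explicit. A hedged sketch of passing them to a concrete model (the model choice and the hparams contents are illustrative, not mandated by this diff):

```python
import torch
from prototorch.models import GLVQ

model = GLVQ(
    hparams=dict(distribution=(3, 2), lr=0.05),           # lr as documented above
    optimizer=torch.optim.SGD,                            # kwargs from the docstring
    lr_scheduler=torch.optim.lr_scheduler.ExponentialLR,
    lr_scheduler_kwargs=dict(gamma=0.99),
)
```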
prototorch/models/callbacks.py

| @@ -40,8 +40,8 @@ class PruneLoserPrototypes(pl.Callback): | ||||
|             return None | ||||
|  | ||||
|         ratios = pl_module.prototype_win_ratios.mean(dim=0) | ||||
|         to_prune = torch.arange(len(ratios))[ratios < self.threshold] | ||||
|         to_prune = to_prune.tolist() | ||||
|         to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold] | ||||
|         to_prune = to_prune_tensor.tolist() | ||||
|         prune_labels = pl_module.prototype_labels[to_prune] | ||||
|         if self.prune_quota_per_epoch > 0: | ||||
|             to_prune = to_prune[:self.prune_quota_per_epoch] | ||||
|   | ||||
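The change only makes the tensor-to-list conversion explicit; the selection logic is unchanged. What it computes, in isolation:

```python
import torch

ratios = torch.tensor([0.00, 0.20, 0.05])  # per-prototype win ratios
threshold = 0.10
to_prune = torch.arange(len(ratios))[ratios < threshold].tolist()
print(to_prune)  # [0, 2] -- prototypes that win too rarely get pruned
```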
prototorch/models/cbc.py

| @@ -1,4 +1,5 @@ | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torchmetrics | ||||
| from prototorch.core.competitions import CBCC | ||||
| from prototorch.core.components import ReasoningComponents | ||||
| @@ -7,12 +8,13 @@ from prototorch.core.losses import MarginLoss | ||||
| from prototorch.core.similarities import euclidean_similarity | ||||
| from prototorch.nn.wrappers import LambdaLayer | ||||
|  | ||||
| from .abstract import ImagePrototypesMixin | ||||
| from .glvq import SiameseGLVQ | ||||
| from .mixins import ImagePrototypesMixin | ||||
|  | ||||
|  | ||||
| class CBC(SiameseGLVQ): | ||||
|     """Classification-By-Components.""" | ||||
|     proto_layer: ReasoningComponents | ||||
|  | ||||
|     def __init__(self, hparams, **kwargs): | ||||
|         super().__init__(hparams, skip_proto_layer=True, **kwargs) | ||||
| @@ -22,7 +24,7 @@ class CBC(SiameseGLVQ): | ||||
|         reasonings_initializer = kwargs.get("reasonings_initializer", | ||||
|                                             RandomReasoningsInitializer()) | ||||
|         self.components_layer = ReasoningComponents( | ||||
|             self.hparams.distribution, | ||||
|             self.hparams["distribution"], | ||||
|             components_initializer=components_initializer, | ||||
|             reasonings_initializer=reasonings_initializer, | ||||
|         ) | ||||
| @@ -32,7 +34,7 @@ class CBC(SiameseGLVQ): | ||||
|         # Namespace hook | ||||
|         self.proto_layer = self.components_layer | ||||
|  | ||||
|         self.loss = MarginLoss(self.hparams.margin) | ||||
|         self.loss = MarginLoss(self.hparams["margin"]) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         components, reasonings = self.components_layer() | ||||
| @@ -48,7 +50,7 @@ class CBC(SiameseGLVQ): | ||||
|         x, y = batch | ||||
|         y_pred = self(x) | ||||
|         num_classes = self.num_classes | ||||
|         y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes) | ||||
|         y_true = F.one_hot(y.long(), num_classes=num_classes) | ||||
|         loss = self.loss(y_pred, y_true).mean() | ||||
|         return y_pred, loss | ||||
|  | ||||
|   | ||||
prototorch/models/glvq.py

| @@ -17,8 +17,9 @@ from prototorch.core.transforms import LinearTransform | ||||
| from prototorch.nn.wrappers import LambdaLayer, LossLayer | ||||
| from torch.nn.parameter import Parameter | ||||
|  | ||||
| from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel | ||||
| from .abstract import SupervisedPrototypeModel | ||||
| from .extras import ltangent_distance, orthogonalization | ||||
| from .mixins import ImagePrototypesMixin | ||||
|  | ||||
|  | ||||
| class GLVQ(SupervisedPrototypeModel): | ||||
| @@ -46,19 +47,24 @@ class GLVQ(SupervisedPrototypeModel): | ||||
|     def initialize_prototype_win_ratios(self): | ||||
|         self.register_buffer( | ||||
|             "prototype_win_ratios", | ||||
|             torch.zeros(self.num_prototypes, device=self.device)) | ||||
|             torch.zeros(self.num_prototypes, device=self.device), | ||||
|         ) | ||||
|  | ||||
|     def on_train_epoch_start(self): | ||||
|         self.initialize_prototype_win_ratios() | ||||
|  | ||||
|     def log_prototype_win_ratios(self, distances): | ||||
|         batch_size = len(distances) | ||||
|         prototype_wc = torch.zeros(self.num_prototypes, | ||||
|         prototype_wc = torch.zeros( | ||||
|             self.num_prototypes, | ||||
|             dtype=torch.long, | ||||
|                                    device=self.device) | ||||
|         wi, wc = torch.unique(distances.min(dim=-1).indices, | ||||
|             device=self.device, | ||||
|         ) | ||||
|         wi, wc = torch.unique( | ||||
|             distances.min(dim=-1).indices, | ||||
|             sorted=True, | ||||
|                               return_counts=True) | ||||
|             return_counts=True, | ||||
|         ) | ||||
|         prototype_wc[wi] = wc | ||||
|         prototype_wr = prototype_wc / batch_size | ||||
|         self.prototype_win_ratios = torch.vstack([ | ||||
| @@ -81,14 +87,12 @@ class GLVQ(SupervisedPrototypeModel): | ||||
|         return train_loss | ||||
|  | ||||
|     def validation_step(self, batch, batch_idx): | ||||
|         # `model.eval()` and `torch.no_grad()` handled by pl | ||||
|         out, val_loss = self.shared_step(batch, batch_idx) | ||||
|         self.log("val_loss", val_loss) | ||||
|         self.log_acc(out, batch[-1], tag="val_acc") | ||||
|         return val_loss | ||||
|  | ||||
|     def test_step(self, batch, batch_idx): | ||||
|         # `model.eval()` and `torch.no_grad()` handled by pl | ||||
|         out, test_loss = self.shared_step(batch, batch_idx) | ||||
|         self.log_acc(out, batch[-1], tag="test_acc") | ||||
|         return test_loss | ||||
| @@ -99,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel): | ||||
|             test_loss += batch_loss.item() | ||||
|         self.log("test_loss", test_loss) | ||||
|  | ||||
|     # TODO | ||||
|     # def predict_step(self, batch, batch_idx, dataloader_idx=None): | ||||
|     #     pass | ||||
|  | ||||
|  | ||||
| class SiameseGLVQ(GLVQ): | ||||
|     """GLVQ in a Siamese setting. | ||||
| @@ -113,19 +113,23 @@ class SiameseGLVQ(GLVQ): | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|     def __init__( | ||||
|             self, | ||||
|             hparams, | ||||
|             backbone=torch.nn.Identity(), | ||||
|             both_path_gradients=False, | ||||
|                  **kwargs): | ||||
|             **kwargs, | ||||
|     ): | ||||
|         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance) | ||||
|         super().__init__(hparams, distance_fn=distance_fn, **kwargs) | ||||
|         self.backbone = backbone | ||||
|         self.both_path_gradients = both_path_gradients | ||||
|  | ||||
|     def configure_optimizers(self): | ||||
|         proto_opt = self.optimizer(self.proto_layer.parameters(), | ||||
|                                    lr=self.hparams["proto_lr"]) | ||||
|         proto_opt = self.optimizer( | ||||
|             self.proto_layer.parameters(), | ||||
|             lr=self.hparams["proto_lr"], | ||||
|         ) | ||||
|         # Only add a backbone optimizer if backbone has trainable parameters | ||||
|         bb_params = list(self.backbone.parameters()) | ||||
|         if (bb_params): | ||||
| @@ -266,13 +270,19 @@ class GMLVQ(GLVQ): | ||||
|         super().__init__(hparams, distance_fn=distance_fn, **kwargs) | ||||
|  | ||||
|         # Additional parameters | ||||
|         omega_initializer = kwargs.get("omega_initializer", | ||||
|                                        EyeLinearTransformInitializer()) | ||||
|         omega = omega_initializer.generate(self.hparams["input_dim"], | ||||
|                                            self.hparams["latent_dim"]) | ||||
|         omega_initializer = kwargs.get( | ||||
|             "omega_initializer", | ||||
|             EyeLinearTransformInitializer(), | ||||
|         ) | ||||
|         omega = omega_initializer.generate( | ||||
|             self.hparams["input_dim"], | ||||
|             self.hparams["latent_dim"], | ||||
|         ) | ||||
|         self.register_parameter("_omega", Parameter(omega)) | ||||
|         self.backbone = LambdaLayer(lambda x: x @ self._omega, | ||||
|                                     name="omega matrix") | ||||
|         self.backbone = LambdaLayer( | ||||
|             lambda x: x @ self._omega, | ||||
|             name="omega matrix", | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def omega_matrix(self): | ||||
|   | ||||
prototorch/models/lvq.py

| @@ -1,20 +1,21 @@ | ||||
| """LVQ models that are optimized using non-gradient methods.""" | ||||
|  | ||||
| import logging | ||||
| from collections import OrderedDict | ||||
|  | ||||
| from prototorch.core.losses import _get_dp_dm | ||||
| from prototorch.nn.activations import get_activation | ||||
| from prototorch.nn.wrappers import LambdaLayer | ||||
|  | ||||
| from .abstract import NonGradientMixin | ||||
| from .glvq import GLVQ | ||||
| from .mixins import NonGradientMixin | ||||
|  | ||||
|  | ||||
| class LVQ1(NonGradientMixin, GLVQ): | ||||
|     """Learning Vector Quantization 1.""" | ||||
|  | ||||
|     def training_step(self, train_batch, batch_idx, optimizer_idx=None): | ||||
|         protos, plables = self.proto_layer() | ||||
|         protos, plabels = self.proto_layer() | ||||
|         x, y = train_batch | ||||
|         dis = self.compute_distances(x) | ||||
|         # TODO Vectorized implementation | ||||
| @@ -28,9 +29,11 @@ class LVQ1(NonGradientMixin, GLVQ): | ||||
|             else: | ||||
|                 shift = protos[w] - xi | ||||
|             updated_protos = protos + 0.0 | ||||
|             updated_protos[w] = protos[w] + (self.hparams.lr * shift) | ||||
|             self.proto_layer.load_state_dict({"_components": updated_protos}, | ||||
|                                              strict=False) | ||||
|             updated_protos[w] = protos[w] + (self.hparams["lr"] * shift) | ||||
|             self.proto_layer.load_state_dict( | ||||
|                 OrderedDict(_components=updated_protos), | ||||
|                 strict=False, | ||||
|             ) | ||||
|  | ||||
|         logging.debug(f"dis={dis}") | ||||
|         logging.debug(f"y={y}") | ||||
| @@ -58,10 +61,12 @@ class LVQ21(NonGradientMixin, GLVQ): | ||||
|             shiftp = xi - protos[wp] | ||||
|             shiftn = protos[wn] - xi | ||||
|             updated_protos = protos + 0.0 | ||||
|             updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp) | ||||
|             updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn) | ||||
|             self.proto_layer.load_state_dict({"_components": updated_protos}, | ||||
|                                              strict=False) | ||||
|             updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp) | ||||
|             updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn) | ||||
|             self.proto_layer.load_state_dict( | ||||
|                 OrderedDict(_components=updated_protos), | ||||
|                 strict=False, | ||||
|             ) | ||||
|  | ||||
|         # Logging | ||||
|         self.log_acc(dis, y, tag="train_acc") | ||||
| @@ -80,14 +85,17 @@ class MedianLVQ(NonGradientMixin, GLVQ): | ||||
|         super().__init__(hparams, **kwargs) | ||||
|  | ||||
|         self.transfer_layer = LambdaLayer( | ||||
|             get_activation(self.hparams.transfer_fn)) | ||||
|             get_activation(self.hparams["transfer_fn"])) | ||||
|  | ||||
|     def _f(self, x, y, protos, plabels): | ||||
|         d = self.distance_layer(x, protos) | ||||
|         dp, dm = _get_dp_dm(d, y, plabels) | ||||
|         dp, dm = _get_dp_dm(d, y, plabels, with_indices=False) | ||||
|         mu = (dp - dm) / (dp + dm) | ||||
|         invmu = -1.0 * mu | ||||
|         f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0 | ||||
|         negative_mu = -1.0 * mu | ||||
|         f = self.transfer_layer( | ||||
|             negative_mu, | ||||
|             beta=self.hparams["transfer_beta"], | ||||
|         ) + 1.0 | ||||
|         return f | ||||
|  | ||||
|     def expectation(self, x, y, protos, plabels): | ||||
| @@ -118,8 +126,10 @@ class MedianLVQ(NonGradientMixin, GLVQ): | ||||
|                 _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma) | ||||
|                 if _lower_bound > lower_bound: | ||||
|                     logging.debug(f"Updating prototype {i} to data {k}...") | ||||
|                     self.proto_layer.load_state_dict({"_components": _protos}, | ||||
|                                                      strict=False) | ||||
|                     self.proto_layer.load_state_dict( | ||||
|                         OrderedDict(_components=_protos), | ||||
|                         strict=False, | ||||
|                     ) | ||||
|                     break | ||||
|  | ||||
|         # Logging | ||||
|   | ||||
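For orientation, the loop bodies above implement the classic LVQ attract/repel rule. Reduced to a single prototype update (values illustrative):

```python
import torch

lr = 0.1
proto = torch.tensor([0.0, 0.0])  # winning prototype
xi = torch.tensor([1.0, 1.0])     # current training sample

attracted = proto + lr * (xi - proto)  # labels match: move towards the sample
repelled = proto + lr * (proto - xi)   # labels differ: move away from it
print(attracted, repelled)  # tensor([0.1000, 0.1000]) tensor([-0.1000, -0.1000])
```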
							
								
								
									
prototorch/models/mixins.py (new file, 35 lines)
							| @@ -0,0 +1,35 @@ | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from prototorch.core.components import Components | ||||
|  | ||||
|  | ||||
| class ProtoTorchMixin(pl.LightningModule): | ||||
|     """All mixins are ProtoTorchMixins.""" | ||||
|  | ||||
|  | ||||
| class NonGradientMixin(ProtoTorchMixin): | ||||
|     """Mixin for custom non-gradient optimization.""" | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.automatic_optimization = False | ||||
|  | ||||
|     def training_step(self, train_batch, batch_idx, optimizer_idx=None): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class ImagePrototypesMixin(ProtoTorchMixin): | ||||
|     """Mixin for models with image prototypes.""" | ||||
|     proto_layer: Components | ||||
|     components: torch.Tensor | ||||
|  | ||||
|     def on_train_batch_end(self, outputs, batch, batch_idx): | ||||
|         """Constrain the components to the range [0, 1] by clamping after updates.""" | ||||
|         self.proto_layer.components.data.clamp_(0.0, 1.0) | ||||
|  | ||||
|     def get_prototype_grid(self, num_columns=2, return_channels_last=True): | ||||
|         from torchvision.utils import make_grid | ||||
|         grid = make_grid(self.components, nrow=num_columns) | ||||
|         if return_channels_last: | ||||
|             grid = grid.permute((1, 2, 0)) | ||||
|         return grid.cpu() | ||||
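The mixins are meant to be combined with a model class via multiple inheritance, mixin first so that its hooks take precedence. A minimal sketch (the subclass name is hypothetical):

```python
from prototorch.models.glvq import GLVQ
from prototorch.models.mixins import ImagePrototypesMixin


class MyImageGLVQ(ImagePrototypesMixin, GLVQ):
    """GLVQ whose prototypes stay clamped to [0, 1] after every training batch."""
```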
prototorch/models/unsupervised.py

| @@ -6,9 +6,10 @@ from prototorch.core.competitions import wtac | ||||
| from prototorch.core.distances import squared_euclidean_distance | ||||
| from prototorch.core.losses import NeuralGasEnergy | ||||
|  | ||||
| from .abstract import NonGradientMixin, UnsupervisedPrototypeModel | ||||
| from .abstract import UnsupervisedPrototypeModel | ||||
| from .callbacks import GNGCallback | ||||
| from .extras import ConnectionTopology | ||||
| from .mixins import NonGradientMixin | ||||
|  | ||||
|  | ||||
| class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): | ||||
|   | ||||
prototorch/models/vis.py

| @@ -1,5 +1,6 @@ | ||||
| """Visualization Callbacks.""" | ||||
|  | ||||
| import os | ||||
| import warnings | ||||
| from typing import Sized | ||||
|  | ||||
| @@ -32,6 +33,10 @@ class Vis2DAbstract(pl.Callback): | ||||
|                  tensorboard=False, | ||||
|                  show_last_only=False, | ||||
|                  pause_time=0.1, | ||||
|                  save=False, | ||||
|                  save_dir="./img", | ||||
|                  fig_size=(5, 4), | ||||
|                  dpi=500, | ||||
|                  block=False): | ||||
|         super().__init__() | ||||
|  | ||||
| @@ -75,8 +80,16 @@ class Vis2DAbstract(pl.Callback): | ||||
|         self.tensorboard = tensorboard | ||||
|         self.show_last_only = show_last_only | ||||
|         self.pause_time = pause_time | ||||
|         self.save = save | ||||
|         self.save_dir = save_dir | ||||
|         self.fig_size = fig_size | ||||
|         self.dpi = dpi | ||||
|         self.block = block | ||||
|  | ||||
|         if save: | ||||
|             if not os.path.exists(save_dir): | ||||
|                 os.makedirs(save_dir) | ||||
|  | ||||
|     def precheck(self, trainer): | ||||
|         if self.show_last_only: | ||||
|             if trainer.current_epoch != trainer.max_epochs - 1: | ||||
| @@ -125,6 +138,11 @@ class Vis2DAbstract(pl.Callback): | ||||
|     def log_and_display(self, trainer, pl_module): | ||||
|         if self.tensorboard: | ||||
|             self.add_to_tensorboard(trainer, pl_module) | ||||
|         if self.save: | ||||
|             plt.tight_layout() | ||||
|             self.fig.set_size_inches(*self.fig_size, forward=False) | ||||
|             plt.savefig(f"{self.save_dir}/{trainer.current_epoch}.png", | ||||
|                         dpi=self.dpi) | ||||
|         if self.show: | ||||
|             if not self.block: | ||||
|                 plt.pause(self.pause_time) | ||||
|   | ||||
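With the new keyword arguments, any Vis2DAbstract subclass can also write one figure per epoch to disk. A usage sketch (`train_ds` stands in for a dataset as in the example script above):

```python
from prototorch.models.vis import VisGLVQ2D

vis = VisGLVQ2D(
    data=train_ds,     # dataset to visualize
    save=True,         # write a PNG every epoch ...
    save_dir="./img",  # ... into this directory (created on init if missing)
    fig_size=(5, 4),
    dpi=300,
)
```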
							
								
								
									
prototorch/y/__init__.py (new file, 23 lines)
							| @@ -0,0 +1,23 @@ | ||||
| from .architectures.base import BaseYArchitecture | ||||
| from .architectures.comparison import ( | ||||
|     OmegaComparisonMixin, | ||||
|     SimpleComparisonMixin, | ||||
| ) | ||||
| from .architectures.competition import WTACompetitionMixin | ||||
| from .architectures.components import SupervisedArchitecture | ||||
| from .architectures.loss import GLVQLossMixin | ||||
| from .architectures.optimization import ( | ||||
|     MultipleLearningRateMixin, | ||||
|     SingleLearningRateMixin, | ||||
| ) | ||||
|  | ||||
| __all__ = [ | ||||
|     'BaseYArchitecture', | ||||
|     "OmegaComparisonMixin", | ||||
|     "SimpleComparisonMixin", | ||||
|     "SingleLearningRateMixin", | ||||
|     "MultipleLearningRateMixin", | ||||
|     "SupervisedArchitecture", | ||||
|     "WTACompetitionMixin", | ||||
|     "GLVQLossMixin", | ||||
| ] | ||||
							
								
								
									
prototorch/y/architectures/base.py (new file, 212 lines)
							| @@ -0,0 +1,212 @@ | ||||
| """ | ||||
| Proto Y Architecture | ||||
|  | ||||
| Network architecture for component-based learning. | ||||
| """ | ||||
| from dataclasses import dataclass | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     Set, | ||||
|     Type, | ||||
| ) | ||||
|  | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from torchmetrics import Metric | ||||
| from torchmetrics.classification.accuracy import Accuracy | ||||
|  | ||||
|  | ||||
| class BaseYArchitecture(pl.LightningModule): | ||||
|  | ||||
|     @dataclass | ||||
|     class HyperParameters: | ||||
|         ... | ||||
|  | ||||
|     registered_metrics: Dict[Type[Metric], Metric] = {} | ||||
|     registered_metric_names: Dict[Type[Metric], Set[str]] = {} | ||||
|  | ||||
|     components_layer: torch.nn.Module | ||||
|  | ||||
|     def __init__(self, hparams) -> None: | ||||
|         super().__init__() | ||||
|  | ||||
|         # Common Steps | ||||
|         self.init_components(hparams) | ||||
|         self.init_latent(hparams) | ||||
|         self.init_comparison(hparams) | ||||
|         self.init_competition(hparams) | ||||
|  | ||||
|         # Train Steps | ||||
|         self.init_loss(hparams) | ||||
|  | ||||
|         # Inference Steps | ||||
|         self.init_inference(hparams) | ||||
|  | ||||
|         # Initialize Model Metrics | ||||
|         self.init_model_metrics() | ||||
|  | ||||
|     # internal API, called by models and callbacks | ||||
|     def register_torchmetric( | ||||
|         self, | ||||
|         name: str, | ||||
|         metric: Type[Metric], | ||||
|         **metric_kwargs, | ||||
|     ): | ||||
|         if metric not in self.registered_metrics: | ||||
|             self.registered_metrics[metric] = metric(**metric_kwargs) | ||||
|             self.registered_metric_names[metric] = {name} | ||||
|         else: | ||||
|             self.registered_metric_names[metric].add(name) | ||||
|  | ||||
|     # external API | ||||
|     def get_competition(self, batch, components): | ||||
|         latent_batch, latent_components = self.latent(batch, components) | ||||
|         # TODO: => Latent Hook | ||||
|         comparison_tensor = self.comparison(latent_batch, latent_components) | ||||
|         # TODO: => Comparison Hook | ||||
|         return comparison_tensor | ||||
|  | ||||
|     def forward(self, batch): | ||||
|         if isinstance(batch, torch.Tensor): | ||||
|             batch = (batch, None) | ||||
|         # TODO: manage different datatypes? | ||||
|         components = self.components_layer() | ||||
|         # TODO: => Component Hook | ||||
|         comparison_tensor = self.get_competition(batch, components) | ||||
|         # TODO: => Competition Hook | ||||
|         return self.inference(comparison_tensor, components) | ||||
|  | ||||
|     def predict(self, batch): | ||||
|         """ | ||||
|         Alias for forward | ||||
|         """ | ||||
|         return self.forward(batch) | ||||
|  | ||||
|     def forward_comparison(self, batch): | ||||
|         if isinstance(batch, torch.Tensor): | ||||
|             batch = (batch, None) | ||||
|         # TODO: manage different datatypes? | ||||
|         components = self.components_layer() | ||||
|         # TODO: => Component Hook | ||||
|         return self.get_competition(batch, components) | ||||
|  | ||||
|     def loss_forward(self, batch): | ||||
|         # TODO: manage different datatypes? | ||||
|         components = self.components_layer() | ||||
|         # TODO: => Component Hook | ||||
|         comparison_tensor = self.get_competition(batch, components) | ||||
|         # TODO: => Competition Hook | ||||
|         return self.loss(comparison_tensor, batch, components) | ||||
|  | ||||
|     # Empty Initialization | ||||
|     # TODO: Type hints | ||||
|     # TODO: Docs | ||||
|     def init_components(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_latent(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_comparison(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_competition(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_loss(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_inference(self, hparams: HyperParameters) -> None: | ||||
|         ... | ||||
|  | ||||
|     def init_model_metrics(self) -> None: | ||||
|         self.register_torchmetric('accuracy', Accuracy) | ||||
|  | ||||
|     # Empty Steps | ||||
|     # TODO: Type hints | ||||
|     def components(self): | ||||
|         """ | ||||
|         This step has no input. | ||||
|  | ||||
|         It returns the components. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "The components step has no reasonable default.") | ||||
|  | ||||
|     def latent(self, batch, components): | ||||
|         """ | ||||
|         The latent step receives the data batch and the components. | ||||
|         It can transform both by an arbitrary function. | ||||
|  | ||||
|         It returns the transformed batch and components, each of the same length as the original input. | ||||
|         """ | ||||
|         return batch, components | ||||
|  | ||||
|     def comparison(self, batch, components): | ||||
|         """ | ||||
|         Takes a batch of size N and the component set of size M. | ||||
|  | ||||
|         It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "The comparison step has no reasonable default.") | ||||
|  | ||||
|     def competition(self, comparison_measures, components): | ||||
|         """ | ||||
|         Takes the tensor of comparison measures. | ||||
|  | ||||
|         Assigns a competition vector to each class. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "The competition step has no reasonable default.") | ||||
|  | ||||
|     def loss(self, comparison_measures, batch, components): | ||||
|         """ | ||||
|         Takes the tensor of comparison measures. | ||||
|  | ||||
|         Calculates a single loss value. | ||||
|         """ | ||||
|         raise NotImplementedError("The loss step has no reasonable default.") | ||||
|  | ||||
|     def inference(self, comparison_measures, components): | ||||
|         """ | ||||
|         Takes the tensor of comparison measures. | ||||
|  | ||||
|         Returns the inferred vector. | ||||
|         """ | ||||
|         raise NotImplementedError( | ||||
|             "The inference step has no reasonable default.") | ||||
|  | ||||
|     def update_metrics_step(self, batch): | ||||
|         x, y = batch | ||||
|  | ||||
|         # Prediction Metrics | ||||
|         preds = self(x) | ||||
|         for metric in self.registered_metrics: | ||||
|             instance = self.registered_metrics[metric].to(self.device) | ||||
|             instance(preds, y)  # torchmetrics expects (preds, target) | ||||
|  | ||||
|     def update_metrics_epoch(self): | ||||
|         for metric in self.registered_metrics: | ||||
|             instance = self.registered_metrics[metric].to(self.device) | ||||
|             value = instance.compute() | ||||
|  | ||||
|             for name in self.registered_metric_names[metric]: | ||||
|                 self.log(name, value) | ||||
|  | ||||
|             instance.reset() | ||||
|  | ||||
|     # Lightning Hooks | ||||
|     def training_step(self, batch, batch_idx, optimizer_idx=None): | ||||
|         self.update_metrics_step(batch) | ||||
|  | ||||
|         return self.loss_forward(batch) | ||||
|  | ||||
|     def training_epoch_end(self, outs) -> None: | ||||
|         self.update_metrics_epoch() | ||||
|  | ||||
|     def validation_step(self, batch, batch_idx): | ||||
|         return self.loss_forward(batch) | ||||
|  | ||||
|     def test_step(self, batch, batch_idx): | ||||
|         return self.loss_forward(batch) | ||||
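The empty steps above are filled in by the mixins defined later in this diff. Assembled into the library's GLVQ, the pipeline reads end to end as follows (`train_ds` and `x` stand in for data as in the example script; hyperparameter values are illustrative):

```python
from prototorch.core import SMCI
from prototorch.y.library.glvq import GLVQ

hparams = GLVQ.HyperParameters(
    distribution=dict(num_classes=3, per_class=1),
    component_initializer=SMCI(train_ds),  # prototypes drawn from class means
    lr=0.05,
)
model = GLVQ(hparams)
y_pred = model(x)  # forward: components -> latent -> comparison -> inference
```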
							
								
								
									
prototorch/y/architectures/comparison.py (new file, 112 lines)
							| @@ -0,0 +1,112 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Callable, Dict | ||||
|  | ||||
| import torch | ||||
| from prototorch.core.distances import euclidean_distance | ||||
| from prototorch.core.initializers import ( | ||||
|     AbstractLinearTransformInitializer, | ||||
|     EyeLinearTransformInitializer, | ||||
| ) | ||||
| from prototorch.nn.wrappers import LambdaLayer | ||||
| from prototorch.y.architectures.base import BaseYArchitecture | ||||
| from torch import Tensor | ||||
| from torch.nn.parameter import Parameter | ||||
|  | ||||
|  | ||||
| class SimpleComparisonMixin(BaseYArchitecture): | ||||
|     """ | ||||
|     Simple Comparison | ||||
|  | ||||
|     A comparison layer that only uses the positions of the components and the batch for dissimilarity computation. | ||||
|     """ | ||||
|  | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters(BaseYArchitecture.HyperParameters): | ||||
|         """ | ||||
|         comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance. | ||||
|         comparison_args: Keyword arguments for the comparison function. Default: {}. | ||||
|         """ | ||||
|         comparison_fn: Callable = euclidean_distance | ||||
|         comparison_args: dict = field(default_factory=lambda: dict()) | ||||
|  | ||||
|         comparison_parameters: dict = field(default_factory=lambda: dict()) | ||||
|  | ||||
|     # Steps | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def init_comparison(self, hparams: HyperParameters): | ||||
|         self.comparison_layer = LambdaLayer( | ||||
|             fn=hparams.comparison_fn, | ||||
|             **hparams.comparison_args, | ||||
|         ) | ||||
|  | ||||
|         self.comparison_kwargs: dict[str, Tensor] = dict() | ||||
|  | ||||
|     def comparison(self, batch, components): | ||||
|         comp_tensor, _ = components | ||||
|         batch_tensor, _ = batch | ||||
|  | ||||
|         comp_tensor = comp_tensor.unsqueeze(1) | ||||
|  | ||||
|         distances = self.comparison_layer( | ||||
|             batch_tensor, | ||||
|             comp_tensor, | ||||
|             **self.comparison_kwargs, | ||||
|         ) | ||||
|  | ||||
|         return distances | ||||
|  | ||||
|  | ||||
| class OmegaComparisonMixin(SimpleComparisonMixin): | ||||
|     """ | ||||
|     Omega Comparison | ||||
|  | ||||
|     A comparison layer that projects both the batch and the components through a learned linear transform (omega) before computing dissimilarities. | ||||
|     """ | ||||
|  | ||||
|     _omega: torch.Tensor | ||||
|  | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters(SimpleComparisonMixin.HyperParameters): | ||||
|         """ | ||||
|         input_dim: Necessary Field: The dimensionality of the input. | ||||
|         latent_dim: The dimensionality of the latent space. Default: 2. | ||||
|         omega_initializer: The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer. | ||||
|         """ | ||||
|         input_dim: int | None = None | ||||
|         latent_dim: int = 2 | ||||
|         omega_initializer: type[ | ||||
|             AbstractLinearTransformInitializer] = EyeLinearTransformInitializer | ||||
|  | ||||
|     # Steps | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def init_comparison(self, hparams: HyperParameters) -> None: | ||||
|         super().init_comparison(hparams) | ||||
|  | ||||
|         # Initialize the omega matrix | ||||
|         if hparams.input_dim is None: | ||||
|             raise ValueError("input_dim must be specified.") | ||||
|         else: | ||||
|             omega = hparams.omega_initializer().generate( | ||||
|                 hparams.input_dim, | ||||
|                 hparams.latent_dim, | ||||
|             ) | ||||
|             self.register_parameter("_omega", Parameter(omega)) | ||||
|             self.comparison_kwargs = dict(omega=self._omega) | ||||
|  | ||||
|     # Properties | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @property | ||||
|     def omega_matrix(self): | ||||
|         return self._omega.detach().cpu() | ||||
|  | ||||
|     @property | ||||
|     def lambda_matrix(self): | ||||
|         omega = self._omega.detach() | ||||
|         lam = omega @ omega.T | ||||
|         return lam.detach().cpu() | ||||
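Since omega has shape (input_dim, latent_dim), the lambda matrix omega @ omega.T is (input_dim, input_dim), and its diagonal reflects per-dimension relevance. A tiny check, assuming the eye initializer yields a truncated identity:

```python
import torch

omega = torch.eye(4, 2)  # 4 input dims projected down to 2 latent dims
lam = omega @ omega.T    # 4 x 4 lambda (relevance) matrix
print(lam.diag())        # tensor([1., 1., 0., 0.]) -- only the first two dims count
```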
							
								
								
									
prototorch/y/architectures/competition.py (new file, 29 lines)
							| @@ -0,0 +1,29 @@ | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from prototorch.core.competitions import WTAC | ||||
| from prototorch.y.architectures.base import BaseYArchitecture | ||||
|  | ||||
|  | ||||
| class WTACompetitionMixin(BaseYArchitecture): | ||||
|     """ | ||||
|     Winner Take All Competition | ||||
|  | ||||
|     A competition layer that uses the winner-take-all strategy. | ||||
|     """ | ||||
|  | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters(BaseYArchitecture.HyperParameters): | ||||
|         """ | ||||
|         No hyperparameters. | ||||
|         """ | ||||
|  | ||||
|     # Steps | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def init_inference(self, hparams: HyperParameters): | ||||
|         self.competition_layer = WTAC() | ||||
|  | ||||
|     def inference(self, comparison_measures, components): | ||||
|         comp_labels = components[1] | ||||
|         return self.competition_layer(comparison_measures, comp_labels) | ||||
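Winner-take-all inference simply returns the label of the closest component. What the WTAC layer (from prototorch.core) computes, spelled out in plain torch:

```python
import torch

distances = torch.tensor([[0.9, 0.1, 0.4]])  # one sample vs. three components
comp_labels = torch.tensor([0, 1, 2])
winners = distances.argmin(dim=1)            # index of the closest component
print(comp_labels[winners])                  # tensor([1])
```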
							
								
								
									
prototorch/y/architectures/components.py (new file, 53 lines)
							| @@ -0,0 +1,53 @@ | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from prototorch.core.components import LabeledComponents | ||||
| from prototorch.core.initializers import ( | ||||
|     AbstractComponentsInitializer, | ||||
|     LabelsInitializer, | ||||
| ) | ||||
| from prototorch.y import BaseYArchitecture | ||||
|  | ||||
|  | ||||
| class SupervisedArchitecture(BaseYArchitecture): | ||||
|     """ | ||||
|     Supervised Architecture | ||||
|  | ||||
|     An architecture that uses a LabeledComponents layer as its component layer. | ||||
|     """ | ||||
|     components_layer: LabeledComponents | ||||
|  | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters: | ||||
|         """ | ||||
|         distribution: A valid prototype distribution. No default possible. | ||||
|         component_initializer: An implementation of AbstractComponentsInitializer. No default possible. | ||||
|         """ | ||||
|         distribution: "dict[str, int]" | ||||
|         component_initializer: AbstractComponentsInitializer | ||||
|  | ||||
|     # Steps | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def init_components(self, hparams: HyperParameters): | ||||
|         self.components_layer = LabeledComponents( | ||||
|             distribution=hparams.distribution, | ||||
|             components_initializer=hparams.component_initializer, | ||||
|             labels_initializer=LabelsInitializer(), | ||||
|         ) | ||||
|  | ||||
|     # Properties | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @property | ||||
|     def prototypes(self): | ||||
|         """ | ||||
|         Returns the position of the prototypes. | ||||
|         """ | ||||
|         return self.components_layer.components.detach().cpu() | ||||
|  | ||||
|     @property | ||||
|     def prototype_labels(self): | ||||
|         """ | ||||
|         Returns the labels of the prototypes. | ||||
|         """ | ||||
|         return self.components_layer.labels.detach().cpu() | ||||
							
								
								
									
prototorch/y/architectures/loss.py (new file, 42 lines)
							| @@ -0,0 +1,42 @@ | ||||
| from dataclasses import dataclass, field | ||||
|  | ||||
| from prototorch.core.losses import GLVQLoss | ||||
| from prototorch.y.architectures.base import BaseYArchitecture | ||||
|  | ||||
|  | ||||
| class GLVQLossMixin(BaseYArchitecture): | ||||
|     """ | ||||
|     GLVQ Loss | ||||
|  | ||||
|     A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss. | ||||
|     """ | ||||
|  | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters(BaseYArchitecture.HyperParameters): | ||||
|         """ | ||||
|         margin: The margin of the GLVQ loss. Default: 0.0. | ||||
|         transfer_fn: Transfer function to use. Default: sigmoid_beta. | ||||
|         transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}. | ||||
|         """ | ||||
|         margin: float = 0.0 | ||||
|  | ||||
|         transfer_fn: str = "sigmoid_beta" | ||||
|         transfer_args: dict = field(default_factory=lambda: dict(beta=10.0)) | ||||
|  | ||||
|     # Steps | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def init_loss(self, hparams: HyperParameters): | ||||
|         self.loss_layer = GLVQLoss( | ||||
|             margin=hparams.margin, | ||||
|             transfer_fn=hparams.transfer_fn, | ||||
|             **hparams.transfer_args, | ||||
|         ) | ||||
|  | ||||
|     def loss(self, comparison_measures, batch, components): | ||||
|         target = batch[1] | ||||
|         comp_labels = components[1] | ||||
|         loss = self.loss_layer(comparison_measures, target, comp_labels) | ||||
|         self.log('loss', loss) | ||||
|         return loss | ||||
							
								
								
									
prototorch/y/architectures/optimization.py (new file, 86 lines)
							| @@ -0,0 +1,86 @@ | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Type | ||||
|  | ||||
| import torch | ||||
| from prototorch.y import BaseYArchitecture | ||||
| from torch.nn.parameter import Parameter | ||||
|  | ||||
|  | ||||
| class SingleLearningRateMixin(BaseYArchitecture): | ||||
|     """ | ||||
|     Single Learning Rate | ||||
|  | ||||
|     All parameters are updated with a single learning rate. | ||||
|     """ | ||||
|  | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters(BaseYArchitecture.HyperParameters): | ||||
|         """ | ||||
|         lr: The learning rate. Default: 0.1. | ||||
|         optimizer: The optimizer to use. Default: torch.optim.Adam. | ||||
|         """ | ||||
|         lr: float = 0.1 | ||||
|         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam | ||||
|  | ||||
|     # Steps | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def __init__(self, hparams: HyperParameters) -> None: | ||||
|         super().__init__(hparams) | ||||
|         self.lr = hparams.lr | ||||
|         self.optimizer = hparams.optimizer | ||||
|  | ||||
|     # Hooks | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def configure_optimizers(self): | ||||
|         return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore | ||||
|  | ||||
|  | ||||
| class MultipleLearningRateMixin(BaseYArchitecture): | ||||
|     """ | ||||
|     Multiple Learning Rates | ||||
|  | ||||
|     Defines a separate learning rate for each named parameter group. | ||||
|     """ | ||||
|  | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters(BaseYArchitecture.HyperParameters): | ||||
|         """ | ||||
|         lr: Maps attribute names of the model to their learning rates. Default: {}. | ||||
|         optimizer: The optimizer to use. Default: torch.optim.Adam. | ||||
|         """ | ||||
|         lr: dict = field(default_factory=lambda: dict()) | ||||
|         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam | ||||
|  | ||||
|     # Steps | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def __init__(self, hparams: HyperParameters) -> None: | ||||
|         super().__init__(hparams) | ||||
|         self.lr = hparams.lr | ||||
|         self.optimizer = hparams.optimizer | ||||
|  | ||||
|     # Hooks | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     def configure_optimizers(self): | ||||
|         optimizers = [] | ||||
|         for name, lr in self.lr.items(): | ||||
|             if not hasattr(self, name): | ||||
|                 raise ValueError(f"{name} is not a parameter of {self}") | ||||
|             else: | ||||
|                 model_part = getattr(self, name) | ||||
|                 if isinstance(model_part, Parameter): | ||||
|                     optimizers.append( | ||||
|                         self.optimizer( | ||||
|                             [model_part], | ||||
|                             lr=lr,  # type: ignore | ||||
|                         )) | ||||
|                 elif hasattr(model_part, "parameters"): | ||||
|                     optimizers.append( | ||||
|                         self.optimizer( | ||||
|                             model_part.parameters(), | ||||
|                             lr=lr,  # type: ignore | ||||
|                         )) | ||||
|         return optimizers | ||||
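The keys of the lr dict must name attributes of the assembled model, since configure_optimizers resolves them with getattr. The example script at the top of this diff relies on exactly this mechanism; spelled out:

```python
from prototorch.y.library.gmlvq import GMLVQ

hyperparameters = GMLVQ.HyperParameters(
    lr=dict(components_layer=0.1, _omega=0.5),     # one optimizer per attribute
    input_dim=4,
    distribution=dict(num_classes=3, per_class=1),
    component_initializer=components_initializer,  # any AbstractComponentsInitializer
)
```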
							
								
								
									
prototorch/y/callbacks.py (new file, 149 lines)
							| @@ -0,0 +1,149 @@ | ||||
| import warnings | ||||
| from typing import Optional, Type | ||||
|  | ||||
| import numpy as np | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| import torchmetrics | ||||
| from matplotlib import pyplot as plt | ||||
| from prototorch.models.vis import Vis2DAbstract | ||||
| from prototorch.utils.utils import mesh2d | ||||
| from prototorch.y.architectures.base import BaseYArchitecture | ||||
| from prototorch.y.library.gmlvq import GMLVQ | ||||
| from pytorch_lightning.loggers import TensorBoardLogger | ||||
|  | ||||
| DIVERGING_COLOR_MAPS = [ | ||||
|     'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn', | ||||
|     'Spectral', 'coolwarm', 'bwr', 'seismic' | ||||
| ] | ||||
|  | ||||
|  | ||||
| class LogTorchmetricCallback(pl.Callback): | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         name, | ||||
|         metric: Type[torchmetrics.Metric], | ||||
|         on="prediction", | ||||
|         **metric_kwargs, | ||||
|     ) -> None: | ||||
|         self.name = name | ||||
|         self.metric = metric | ||||
|         self.metric_kwargs = metric_kwargs | ||||
|         self.on = on | ||||
|  | ||||
|     def setup( | ||||
|         self, | ||||
|         trainer: pl.Trainer, | ||||
|         pl_module: BaseYArchitecture, | ||||
|         stage: Optional[str] = None, | ||||
|     ) -> None: | ||||
|         if self.on == "prediction": | ||||
|             pl_module.register_torchmetric( | ||||
|                 self.name, | ||||
|                 self.metric, | ||||
|                 **self.metric_kwargs, | ||||
|             ) | ||||
|         else: | ||||
|             raise ValueError(f"{self.on} is not a valid metric hook") | ||||
|  | ||||
|  | ||||
| class VisGLVQ2D(Vis2DAbstract): | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         protos = pl_module.prototypes | ||||
|         plabels = pl_module.prototype_labels | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
|         ax = self.setup_ax() | ||||
|         self.plot_protos(ax, protos, plabels) | ||||
|         if x_train is not None: | ||||
|             self.plot_data(ax, x_train, y_train) | ||||
|             mesh_input, xx, yy = mesh2d( | ||||
|                 np.vstack([x_train, protos]), | ||||
|                 self.border, | ||||
|                 self.resolution, | ||||
|             ) | ||||
|         else: | ||||
|             mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution) | ||||
|         _components = pl_module.components_layer.components | ||||
|         mesh_input = torch.from_numpy(mesh_input).type_as(_components) | ||||
|         y_pred = pl_module.predict(mesh_input) | ||||
|         y_pred = y_pred.cpu().reshape(xx.shape) | ||||
|         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) | ||||
|  | ||||
|  | ||||
| class VisGMLVQ2D(Vis2DAbstract): | ||||
|  | ||||
|     def __init__(self, *args, ev_proj=True, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.ev_proj = ev_proj | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         protos = pl_module.prototypes | ||||
|         plabels = pl_module.prototype_labels | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
|         device = pl_module.device | ||||
|         omega = pl_module._omega.detach() | ||||
|         lam = omega @ omega.T | ||||
|         u, _, _ = torch.pca_lowrank(lam, q=2) | ||||
|         with torch.no_grad(): | ||||
|             x_train = torch.Tensor(x_train).to(device) | ||||
|             x_train = x_train @ u | ||||
|             x_train = x_train.cpu().detach() | ||||
|         if self.show_protos: | ||||
|             with torch.no_grad(): | ||||
|                 protos = torch.Tensor(protos).to(device) | ||||
|                 protos = protos @ u | ||||
|                 protos = protos.cpu().detach() | ||||
|         ax = self.setup_ax() | ||||
|         self.plot_data(ax, x_train, y_train) | ||||
|         if self.show_protos: | ||||
|             self.plot_protos(ax, protos, plabels) | ||||
|  | ||||
|  | ||||
| class PlotLambdaMatrixToTensorboard(pl.Callback): | ||||
|  | ||||
|     def __init__(self, cmap='seismic') -> None: | ||||
|         super().__init__() | ||||
|         self.cmap = cmap | ||||
|  | ||||
|         if self.cmap not in DIVERGING_COLOR_MAPS and type(self.cmap) is str: | ||||
|             warnings.warn( | ||||
|                 f"{self.cmap} is not a diverging color map. We recommend using one of the following: {DIVERGING_COLOR_MAPS}" | ||||
|             ) | ||||
|  | ||||
|     def on_train_start(self, trainer, pl_module: GMLVQ): | ||||
|         self.plot_lambda(trainer, pl_module) | ||||
|  | ||||
|     def on_train_epoch_end(self, trainer, pl_module: GMLVQ): | ||||
|         self.plot_lambda(trainer, pl_module) | ||||
|  | ||||
|     def plot_lambda(self, trainer, pl_module: GMLVQ): | ||||
|  | ||||
|         self.fig, self.ax = plt.subplots(1, 1) | ||||
|  | ||||
|         # plot lambda matrix | ||||
|         l_matrix = pl_module.lambda_matrix | ||||
|  | ||||
|         # normalize lambda matrix | ||||
|         l_matrix = l_matrix / torch.max(torch.abs(l_matrix)) | ||||
|  | ||||
|         # plot lambda matrix | ||||
|         self.ax.imshow(l_matrix.detach().numpy(), self.cmap, vmin=-1, vmax=1) | ||||
|  | ||||
|         self.fig.colorbar(self.ax.images[-1]) | ||||
|  | ||||
|         # add title | ||||
|         self.ax.set_title('Lambda Matrix') | ||||
|  | ||||
|         # add to tensorboard | ||||
|         if isinstance(trainer.logger, TensorBoardLogger): | ||||
|             trainer.logger.experiment.add_figure( | ||||
|                 f"lambda_matrix", | ||||
|                 self.fig, | ||||
|                 trainer.global_step, | ||||
|             ) | ||||
|         else: | ||||
|             warnings.warn( | ||||
|                 f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead." | ||||
|             ) | ||||
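As the final branch warns, PlotLambdaMatrixToTensorboard only logs figures when the trainer is wired to a TensorBoardLogger. A matching trainer setup (log directory illustrative):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

trainer = pl.Trainer(
    logger=TensorBoardLogger("logs"),
    callbacks=[PlotLambdaMatrixToTensorboard(cmap="RdBu")],
    max_epochs=100,
)
```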
							
								
								
									
prototorch/y/library/__init__.py (new file, 5 lines)
							| @@ -0,0 +1,5 @@ | ||||
| from .glvq import GLVQ | ||||
|  | ||||
| __all__ = [ | ||||
|     "GLVQ", | ||||
| ] | ||||
							
								
								
									
prototorch/y/library/glvq.py (new file, 35 lines)
							| @@ -0,0 +1,35 @@ | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from prototorch.y import ( | ||||
|     SimpleComparisonMixin, | ||||
|     SingleLearningRateMixin, | ||||
|     SupervisedArchitecture, | ||||
|     WTACompetitionMixin, | ||||
| ) | ||||
| from prototorch.y.architectures.loss import GLVQLossMixin | ||||
|  | ||||
|  | ||||
| class GLVQ( | ||||
|         SupervisedArchitecture, | ||||
|         SimpleComparisonMixin, | ||||
|         GLVQLossMixin, | ||||
|         WTACompetitionMixin, | ||||
|         SingleLearningRateMixin, | ||||
| ): | ||||
|     """ | ||||
|     Generalized Learning Vector Quantization (GLVQ) | ||||
|  | ||||
|     A GLVQ architecture that uses the winner-take-all strategy and the GLVQ loss. | ||||
|     """ | ||||
|  | ||||
|     @dataclass | ||||
|     class HyperParameters( | ||||
|             SimpleComparisonMixin.HyperParameters, | ||||
|             SingleLearningRateMixin.HyperParameters, | ||||
|             GLVQLossMixin.HyperParameters, | ||||
|             WTACompetitionMixin.HyperParameters, | ||||
|             SupervisedArchitecture.HyperParameters, | ||||
|     ): | ||||
|         """ | ||||
|         No additional hyperparameters; all are inherited from the mixins above. | ||||
|         """ | ||||
							
								
								
									
prototorch/y/library/gmlvq.py (new file, 50 lines)
							| @@ -0,0 +1,50 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Callable | ||||
|  | ||||
| import torch | ||||
| from prototorch.core.distances import omega_distance | ||||
| from prototorch.y import ( | ||||
|     GLVQLossMixin, | ||||
|     MultipleLearningRateMixin, | ||||
|     OmegaComparisonMixin, | ||||
|     SupervisedArchitecture, | ||||
|     WTACompetitionMixin, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class GMLVQ( | ||||
|         SupervisedArchitecture, | ||||
|         OmegaComparisonMixin, | ||||
|         GLVQLossMixin, | ||||
|         WTACompetitionMixin, | ||||
|         MultipleLearningRateMixin, | ||||
| ): | ||||
|     """ | ||||
|     Generalized Matrix Learning Vector Quantization (GMLVQ) | ||||
|  | ||||
|     A GMLVQ architecture that uses the winner-take-all strategy and the GLVQ loss. | ||||
|     """ | ||||
|     # HyperParameters | ||||
|     # ---------------------------------------------------------------------------------------------------- | ||||
|     @dataclass | ||||
|     class HyperParameters( | ||||
|             MultipleLearningRateMixin.HyperParameters, | ||||
|             OmegaComparisonMixin.HyperParameters, | ||||
|             GLVQLossMixin.HyperParameters, | ||||
|             WTACompetitionMixin.HyperParameters, | ||||
|             SupervisedArchitecture.HyperParameters, | ||||
|     ): | ||||
|         """ | ||||
|         comparison_fn: The comparison / dissimilarity function to use. Override Default: omega_distance. | ||||
|         comparison_args: Keyword arguments for the comparison function. Override Default: {}. | ||||
|         """ | ||||
|         comparison_fn: Callable = omega_distance | ||||
|         comparison_args: dict = field(default_factory=lambda: dict()) | ||||
|         optimizer: type[torch.optim.Optimizer] = torch.optim.Adam | ||||
|  | ||||
|         lr: dict = field(default_factory=lambda: dict( | ||||
|             components_layer=0.1, | ||||
|             _omega=0.5, | ||||
|         )) | ||||
							
								
								
									
setup.py (2 lines changed)
							| @@ -55,7 +55,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS | ||||
|  | ||||
| setup( | ||||
|     name=safe_name("prototorch_" + PLUGIN_NAME), | ||||
|     version="0.5.2", | ||||
|     version="1.0.0-a2", | ||||
|     description="Pre-packaged prototype-based " | ||||
|     "machine learning models using ProtoTorch and PyTorch-Lightning.", | ||||
|     long_description=long_description, | ||||
|   | ||||