chore: move mixins to separate file
		| @@ -22,7 +22,16 @@ from prototorch.nn.wrappers import LambdaLayer | ||||
|  | ||||
|  | ||||
| class ProtoTorchBolt(pl.LightningModule): | ||||
|     """All ProtoTorch models are ProtoTorch Bolts.""" | ||||
|     """All ProtoTorch models are ProtoTorch Bolts. | ||||
|  | ||||
|     hparams: | ||||
|         - lr: learning rate | ||||
|  | ||||
|     kwargs: | ||||
|         - optimizer: optimizer class | ||||
|         - lr_scheduler: learning rate scheduler class | ||||
|         - lr_scheduler_kwargs: learning rate scheduler kwargs | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, hparams, **kwargs): | ||||
|         super().__init__() | ||||
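The expanded docstring pins down the split between `hparams` (tunable values such as `lr`) and `kwargs` (structural choices such as the optimizer class). A minimal sketch of that contract, using a hypothetical subclass rather than the real `ProtoTorchBolt` internals:

```python
import torch
import pytorch_lightning as pl

class MyBolt(pl.LightningModule):  # hypothetical stand-in for ProtoTorchBolt
    def __init__(self, hparams, **kwargs):
        super().__init__()
        self.save_hyperparameters(hparams)
        # kwargs configure behaviour that is not a tunable hyperparameter.
        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
        self.lr_scheduler = kwargs.get("lr_scheduler", None)
        self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())

    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
        if self.lr_scheduler is None:
            return optimizer
        scheduler = self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
        return [optimizer], [scheduler]

model = MyBolt(dict(lr=0.01), optimizer=torch.optim.SGD)
```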
| @@ -65,6 +74,11 @@ class ProtoTorchBolt(pl.LightningModule): | ||||
|  | ||||
|  | ||||
| class PrototypeModel(ProtoTorchBolt): | ||||
|     """Abstract Prototype Model | ||||
|  | ||||
|     kwargs: | ||||
|         - distance_fn: distance function | ||||
|     """ | ||||
|     proto_layer: AbstractComponents | ||||
|  | ||||
|     def __init__(self, hparams, **kwargs): | ||||
| @@ -203,35 +217,3 @@ class SupervisedPrototypeModel(PrototypeModel): | ||||
|         accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int()) | ||||
|  | ||||
|         self.log("test_acc", accuracy) | ||||
|  | ||||
|  | ||||
| class ProtoTorchMixin(object): | ||||
|     """All mixins are ProtoTorchMixins.""" | ||||
|  | ||||
|  | ||||
| class NonGradientMixin(ProtoTorchMixin): | ||||
|     """Mixin for custom non-gradient optimization.""" | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.automatic_optimization = False | ||||
|  | ||||
|     def training_step(self, train_batch, batch_idx, optimizer_idx=None): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class ImagePrototypesMixin(ProtoTorchMixin): | ||||
|     """Mixin for models with image prototypes.""" | ||||
|     proto_layer: Components | ||||
|     components: torch.Tensor | ||||
|  | ||||
|     def on_train_batch_end(self, outputs, batch, batch_idx): | ||||
|         """Constrain the components to the range [0, 1] by clamping after updates.""" | ||||
|         self.proto_layer.components.data.clamp_(0.0, 1.0) | ||||
|  | ||||
|     def get_prototype_grid(self, num_columns=2, return_channels_last=True): | ||||
|         from torchvision.utils import make_grid | ||||
|         grid = make_grid(self.components, nrow=num_columns) | ||||
|         if return_channels_last: | ||||
|             grid = grid.permute((1, 2, 0)) | ||||
|         return grid.cpu() | ||||
|   | ||||
| @@ -40,8 +40,8 @@ class PruneLoserPrototypes(pl.Callback): | ||||
|             return None | ||||
|  | ||||
|         ratios = pl_module.prototype_win_ratios.mean(dim=0) | ||||
|         to_prune = torch.arange(len(ratios))[ratios < self.threshold] | ||||
|         to_prune = to_prune.tolist() | ||||
|         to_prune_tensor = torch.arange(len(ratios))[ratios < self.threshold] | ||||
|         to_prune = to_prune_tensor.tolist() | ||||
|         prune_labels = pl_module.prototype_labels[to_prune] | ||||
|         if self.prune_quota_per_epoch > 0: | ||||
|             to_prune = to_prune[:self.prune_quota_per_epoch] | ||||
|   | ||||
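The renamed intermediate makes explicit that the pruning indices leave tensor form before they are used for quota slicing and label lookup. A standalone sketch of the same selection logic with invented ratios and threshold:

```python
import torch

# Hypothetical win ratios for 5 prototypes and a pruning threshold.
ratios = torch.tensor([0.30, 0.01, 0.25, 0.00, 0.02])
threshold = 0.05
prune_quota_per_epoch = 2

to_prune_tensor = torch.arange(len(ratios))[ratios < threshold]
to_prune = to_prune_tensor.tolist()          # plain ints: [1, 3, 4]
to_prune = to_prune[:prune_quota_per_epoch]  # respect the per-epoch quota

prototype_labels = torch.tensor([0, 0, 1, 1, 2])
prune_labels = prototype_labels[to_prune]    # labels of the pruned prototypes
print(to_prune, prune_labels)                # [1, 3] tensor([0, 1])
```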
| @@ -1,4 +1,5 @@ | ||||
| import torch | ||||
| import torch.nn.functional as F | ||||
| import torchmetrics | ||||
| from prototorch.core.competitions import CBCC | ||||
| from prototorch.core.components import ReasoningComponents | ||||
| @@ -7,12 +8,13 @@ from prototorch.core.losses import MarginLoss | ||||
| from prototorch.core.similarities import euclidean_similarity | ||||
| from prototorch.nn.wrappers import LambdaLayer | ||||
|  | ||||
| from .abstract import ImagePrototypesMixin | ||||
| from .glvq import SiameseGLVQ | ||||
| from .mixins import ImagePrototypesMixin | ||||
|  | ||||
|  | ||||
| class CBC(SiameseGLVQ): | ||||
|     """Classification-By-Components.""" | ||||
|     proto_layer: ReasoningComponents | ||||
|  | ||||
|     def __init__(self, hparams, **kwargs): | ||||
|         super().__init__(hparams, skip_proto_layer=True, **kwargs) | ||||
| @@ -22,7 +24,7 @@ class CBC(SiameseGLVQ): | ||||
|         reasonings_initializer = kwargs.get("reasonings_initializer", | ||||
|                                             RandomReasoningsInitializer()) | ||||
|         self.components_layer = ReasoningComponents( | ||||
|             self.hparams.distribution, | ||||
|             self.hparams["distribution"], | ||||
|             components_initializer=components_initializer, | ||||
|             reasonings_initializer=reasonings_initializer, | ||||
|         ) | ||||
| @@ -32,7 +34,7 @@ class CBC(SiameseGLVQ): | ||||
|         # Namespace hook | ||||
|         self.proto_layer = self.components_layer | ||||
|  | ||||
|         self.loss = MarginLoss(self.hparams.margin) | ||||
|         self.loss = MarginLoss(self.hparams["margin"]) | ||||
|  | ||||
|     def forward(self, x): | ||||
|         components, reasonings = self.components_layer() | ||||
| @@ -48,7 +50,7 @@ class CBC(SiameseGLVQ): | ||||
|         x, y = batch | ||||
|         y_pred = self(x) | ||||
|         num_classes = self.num_classes | ||||
|         y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes) | ||||
|         y_true = F.one_hot(y.long(), num_classes=num_classes) | ||||
|         loss = self.loss(y_pred, y_true).mean() | ||||
|         return y_pred, loss | ||||
|  | ||||
|   | ||||
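Besides the dict-style `hparams` access, the CBC shared step now uses the `F.one_hot` shorthand imported at the top of the file. A quick, self-contained illustration of what that call produces:

```python
import torch
import torch.nn.functional as F

# Toy targets for a 4-class problem; F.one_hot is the shorthand the diff
# switches to (equivalent to torch.nn.functional.one_hot).
y = torch.tensor([2, 0, 3])
y_true = F.one_hot(y.long(), num_classes=4)
print(y_true)
# tensor([[0, 0, 1, 0],
#         [1, 0, 0, 0],
#         [0, 0, 0, 1]])
```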
| @@ -17,8 +17,9 @@ from prototorch.core.transforms import LinearTransform | ||||
| from prototorch.nn.wrappers import LambdaLayer, LossLayer | ||||
| from torch.nn.parameter import Parameter | ||||
|  | ||||
| from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel | ||||
| from .abstract import SupervisedPrototypeModel | ||||
| from .extras import ltangent_distance, orthogonalization | ||||
| from .mixins import ImagePrototypesMixin | ||||
|  | ||||
|  | ||||
| class GLVQ(SupervisedPrototypeModel): | ||||
| @@ -46,19 +47,24 @@ class GLVQ(SupervisedPrototypeModel): | ||||
|     def initialize_prototype_win_ratios(self): | ||||
|         self.register_buffer( | ||||
|             "prototype_win_ratios", | ||||
|             torch.zeros(self.num_prototypes, device=self.device)) | ||||
|             torch.zeros(self.num_prototypes, device=self.device), | ||||
|         ) | ||||
|  | ||||
|     def on_train_epoch_start(self): | ||||
|         self.initialize_prototype_win_ratios() | ||||
|  | ||||
|     def log_prototype_win_ratios(self, distances): | ||||
|         batch_size = len(distances) | ||||
|         prototype_wc = torch.zeros(self.num_prototypes, | ||||
|                                    dtype=torch.long, | ||||
|                                    device=self.device) | ||||
|         wi, wc = torch.unique(distances.min(dim=-1).indices, | ||||
|                               sorted=True, | ||||
|                               return_counts=True) | ||||
|         prototype_wc = torch.zeros( | ||||
|             self.num_prototypes, | ||||
|             dtype=torch.long, | ||||
|             device=self.device, | ||||
|         ) | ||||
|         wi, wc = torch.unique( | ||||
|             distances.min(dim=-1).indices, | ||||
|             sorted=True, | ||||
|             return_counts=True, | ||||
|         ) | ||||
|         prototype_wc[wi] = wc | ||||
|         prototype_wr = prototype_wc / batch_size | ||||
|         self.prototype_win_ratios = torch.vstack([ | ||||
| @@ -81,14 +87,12 @@ class GLVQ(SupervisedPrototypeModel): | ||||
|         return train_loss | ||||
|  | ||||
|     def validation_step(self, batch, batch_idx): | ||||
|         # `model.eval()` and `torch.no_grad()` handled by pl | ||||
|         out, val_loss = self.shared_step(batch, batch_idx) | ||||
|         self.log("val_loss", val_loss) | ||||
|         self.log_acc(out, batch[-1], tag="val_acc") | ||||
|         return val_loss | ||||
|  | ||||
|     def test_step(self, batch, batch_idx): | ||||
|         # `model.eval()` and `torch.no_grad()` handled by pl | ||||
|         out, test_loss = self.shared_step(batch, batch_idx) | ||||
|         self.log_acc(out, batch[-1], tag="test_acc") | ||||
|         return test_loss | ||||
| @@ -99,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel): | ||||
|             test_loss += batch_loss.item() | ||||
|         self.log("test_loss", test_loss) | ||||
|  | ||||
|     # TODO | ||||
|     # def predict_step(self, batch, batch_idx, dataloader_idx=None): | ||||
|     #     pass | ||||
|  | ||||
|  | ||||
| class SiameseGLVQ(GLVQ): | ||||
|     """GLVQ in a Siamese setting. | ||||
| @@ -113,19 +113,23 @@ class SiameseGLVQ(GLVQ): | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, | ||||
|                  hparams, | ||||
|                  backbone=torch.nn.Identity(), | ||||
|                  both_path_gradients=False, | ||||
|                  **kwargs): | ||||
|     def __init__( | ||||
|             self, | ||||
|             hparams, | ||||
|             backbone=torch.nn.Identity(), | ||||
|             both_path_gradients=False, | ||||
|             **kwargs, | ||||
|     ): | ||||
|         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance) | ||||
|         super().__init__(hparams, distance_fn=distance_fn, **kwargs) | ||||
|         self.backbone = backbone | ||||
|         self.both_path_gradients = both_path_gradients | ||||
|  | ||||
|     def configure_optimizers(self): | ||||
|         proto_opt = self.optimizer(self.proto_layer.parameters(), | ||||
|                                    lr=self.hparams["proto_lr"]) | ||||
|         proto_opt = self.optimizer( | ||||
|             self.proto_layer.parameters(), | ||||
|             lr=self.hparams["proto_lr"], | ||||
|         ) | ||||
|         # Only add a backbone optimizer if backbone has trainable parameters | ||||
|         bb_params = list(self.backbone.parameters()) | ||||
|         if (bb_params): | ||||
| @@ -266,13 +270,19 @@ class GMLVQ(GLVQ): | ||||
|         super().__init__(hparams, distance_fn=distance_fn, **kwargs) | ||||
|  | ||||
|         # Additional parameters | ||||
|         omega_initializer = kwargs.get("omega_initializer", | ||||
|                                        EyeLinearTransformInitializer()) | ||||
|         omega = omega_initializer.generate(self.hparams["input_dim"], | ||||
|                                            self.hparams["latent_dim"]) | ||||
|         omega_initializer = kwargs.get( | ||||
|             "omega_initializer", | ||||
|             EyeLinearTransformInitializer(), | ||||
|         ) | ||||
|         omega = omega_initializer.generate( | ||||
|             self.hparams["input_dim"], | ||||
|             self.hparams["latent_dim"], | ||||
|         ) | ||||
|         self.register_parameter("_omega", Parameter(omega)) | ||||
|         self.backbone = LambdaLayer(lambda x: x @ self._omega, | ||||
|                                     name="omega matrix") | ||||
|         self.backbone = LambdaLayer( | ||||
|             lambda x: x @ self._omega, | ||||
|             name="omega matrix", | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def omega_matrix(self): | ||||
|   | ||||
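The reformatted GMLVQ block still builds an identity-initialized omega and wraps the projection `x @ self._omega` in a `LambdaLayer`. A torch-only sketch of that latent projection (dimensions invented; the real code obtains omega from `EyeLinearTransformInitializer`):

```python
import torch

input_dim, latent_dim = 4, 2

# Identity-like initialization for a rectangular omega, mimicking what an
# eye-style initializer would produce (illustrative, not the library call).
omega = torch.eye(input_dim, latent_dim)

x = torch.randn(8, input_dim)   # a batch of 8 samples
latent = x @ omega              # what the LambdaLayer backbone computes
print(latent.shape)             # torch.Size([8, 2])
```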
| @@ -1,20 +1,21 @@ | ||||
| """LVQ models that are optimized using non-gradient methods.""" | ||||
|  | ||||
| import logging | ||||
| from collections import OrderedDict | ||||
|  | ||||
| from prototorch.core.losses import _get_dp_dm | ||||
| from prototorch.nn.activations import get_activation | ||||
| from prototorch.nn.wrappers import LambdaLayer | ||||
|  | ||||
| from .abstract import NonGradientMixin | ||||
| from .glvq import GLVQ | ||||
| from .mixins import NonGradientMixin | ||||
|  | ||||
|  | ||||
| class LVQ1(NonGradientMixin, GLVQ): | ||||
|     """Learning Vector Quantization 1.""" | ||||
|  | ||||
|     def training_step(self, train_batch, batch_idx, optimizer_idx=None): | ||||
|         protos, plables = self.proto_layer() | ||||
|         protos, plabels = self.proto_layer() | ||||
|         x, y = train_batch | ||||
|         dis = self.compute_distances(x) | ||||
|         # TODO Vectorized implementation | ||||
| @@ -28,9 +29,11 @@ class LVQ1(NonGradientMixin, GLVQ): | ||||
|             else: | ||||
|                 shift = protos[w] - xi | ||||
|             updated_protos = protos + 0.0 | ||||
|             updated_protos[w] = protos[w] + (self.hparams.lr * shift) | ||||
|             self.proto_layer.load_state_dict({"_components": updated_protos}, | ||||
|                                              strict=False) | ||||
|             updated_protos[w] = protos[w] + (self.hparams["lr"] * shift) | ||||
|             self.proto_layer.load_state_dict( | ||||
|                 OrderedDict(_components=updated_protos), | ||||
|                 strict=False, | ||||
|             ) | ||||
|  | ||||
|         logging.debug(f"dis={dis}") | ||||
|         logging.debug(f"y={y}") | ||||
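For readers following the reformatted LVQ1 hunk: the winning prototype is pulled toward the sample on a label match and pushed away otherwise, and the result is then written back through `load_state_dict`. A self-contained sketch of one such update on toy data (learning rate invented):

```python
import torch

lr = 0.1
protos = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
plabels = torch.tensor([0, 1])

xi, yi = torch.tensor([0.2, 0.0]), 0       # one training sample and its label
d = torch.cdist(xi.unsqueeze(0), protos)   # distances to all prototypes
w = d.argmin().item()                      # winning prototype index

# Attract on label match, repel otherwise -- the LVQ1 rule.
shift = (xi - protos[w]) if plabels[w] == yi else (protos[w] - xi)
updated_protos = protos + 0.0              # copy, as in the diff above
updated_protos[w] = protos[w] + lr * shift
print(updated_protos)
```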
| @@ -58,10 +61,12 @@ class LVQ21(NonGradientMixin, GLVQ): | ||||
|             shiftp = xi - protos[wp] | ||||
|             shiftn = protos[wn] - xi | ||||
|             updated_protos = protos + 0.0 | ||||
|             updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp) | ||||
|             updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn) | ||||
|             self.proto_layer.load_state_dict({"_components": updated_protos}, | ||||
|                                              strict=False) | ||||
|             updated_protos[wp] = protos[wp] + (self.hparams["lr"] * shiftp) | ||||
|             updated_protos[wn] = protos[wn] + (self.hparams["lr"] * shiftn) | ||||
|             self.proto_layer.load_state_dict( | ||||
|                 OrderedDict(_components=updated_protos), | ||||
|                 strict=False, | ||||
|             ) | ||||
|  | ||||
|         # Logging | ||||
|         self.log_acc(dis, y, tag="train_acc") | ||||
| @@ -80,14 +85,17 @@ class MedianLVQ(NonGradientMixin, GLVQ): | ||||
|         super().__init__(hparams, **kwargs) | ||||
|  | ||||
|         self.transfer_layer = LambdaLayer( | ||||
|             get_activation(self.hparams.transfer_fn)) | ||||
|             get_activation(self.hparams["transfer_fn"])) | ||||
|  | ||||
|     def _f(self, x, y, protos, plabels): | ||||
|         d = self.distance_layer(x, protos) | ||||
|         dp, dm = _get_dp_dm(d, y, plabels) | ||||
|         dp, dm = _get_dp_dm(d, y, plabels, with_indices=False) | ||||
|         mu = (dp - dm) / (dp + dm) | ||||
|         invmu = -1.0 * mu | ||||
|         f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0 | ||||
|         negative_mu = -1.0 * mu | ||||
|         f = self.transfer_layer( | ||||
|             negative_mu, | ||||
|             beta=self.hparams["transfer_beta"], | ||||
|         ) + 1.0 | ||||
|         return f | ||||
|  | ||||
|     def expectation(self, x, y, protos, plabels): | ||||
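The `_f` helper now asks `_get_dp_dm` for plain distances (no indices) and builds the classifier function from `mu = (dp - dm) / (dp + dm)` before feeding `-mu` to the transfer function. A toy computation of that quantity, with a sigmoid standing in for the configurable transfer function:

```python
import torch

# dp: distance to the closest prototype with the correct label,
# dm: distance to the closest prototype with a wrong label (values invented).
dp = torch.tensor([0.2, 0.9])
dm = torch.tensor([0.8, 0.3])

mu = (dp - dm) / (dp + dm)          # in [-1, 1]; negative means correctly classified
negative_mu = -1.0 * mu
beta = 10.0                         # steepness, analogous to transfer_beta
f = torch.sigmoid(beta * negative_mu) + 1.0
print(mu, f)
```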
| @@ -118,8 +126,10 @@ class MedianLVQ(NonGradientMixin, GLVQ): | ||||
|                 _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma) | ||||
|                 if _lower_bound > lower_bound: | ||||
|                     logging.debug(f"Updating prototype {i} to data {k}...") | ||||
|                     self.proto_layer.load_state_dict({"_components": _protos}, | ||||
|                                                      strict=False) | ||||
|                     self.proto_layer.load_state_dict( | ||||
|                         OrderedDict(_components=_protos), | ||||
|                         strict=False, | ||||
|                     ) | ||||
|                     break | ||||
|  | ||||
|         # Logging | ||||
|   | ||||
							
								
								
									
prototorch/models/mixins.py (new file, 35 lines)
							| @@ -0,0 +1,35 @@ | ||||
| import pytorch_lightning as pl | ||||
| import torch | ||||
| from prototorch.core.components import Components | ||||
|  | ||||
|  | ||||
| class ProtoTorchMixin(pl.LightningModule): | ||||
|     """All mixins are ProtoTorchMixins.""" | ||||
|  | ||||
|  | ||||
| class NonGradientMixin(ProtoTorchMixin): | ||||
|     """Mixin for custom non-gradient optimization.""" | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.automatic_optimization = False | ||||
|  | ||||
|     def training_step(self, train_batch, batch_idx, optimizer_idx=None): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|  | ||||
| class ImagePrototypesMixin(ProtoTorchMixin): | ||||
|     """Mixin for models with image prototypes.""" | ||||
|     proto_layer: Components | ||||
|     components: torch.Tensor | ||||
|  | ||||
|     def on_train_batch_end(self, outputs, batch, batch_idx): | ||||
|         """Constrain the components to the range [0, 1] by clamping after updates.""" | ||||
|         self.proto_layer.components.data.clamp_(0.0, 1.0) | ||||
|  | ||||
|     def get_prototype_grid(self, num_columns=2, return_channels_last=True): | ||||
|         from torchvision.utils import make_grid | ||||
|         grid = make_grid(self.components, nrow=num_columns) | ||||
|         if return_channels_last: | ||||
|             grid = grid.permute((1, 2, 0)) | ||||
|         return grid.cpu() | ||||
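Since `ProtoTorchMixin` now derives from `pl.LightningModule`, the mixins can take part in the cooperative `super()` chain of any model class, as in `class LVQ1(NonGradientMixin, GLVQ)` elsewhere in this commit. A sketch with a hypothetical stand-in base class:

```python
import torch
import pytorch_lightning as pl
from prototorch.models.mixins import NonGradientMixin  # the new module above

class ToyPrototypeModel(pl.LightningModule):  # hypothetical stand-in for GLVQ
    def __init__(self, hparams, **kwargs):
        super().__init__()
        self.protos = torch.nn.Parameter(torch.zeros(2, 2))

# Mixin listed first so its __init__ runs in the cooperative super() chain
# and switches the combined model to manual optimization.
class ToyLVQ(NonGradientMixin, ToyPrototypeModel):
    pass

model = ToyLVQ(dict(lr=0.1))
print(model.automatic_optimization)  # False
```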
| @@ -6,9 +6,10 @@ from prototorch.core.competitions import wtac | ||||
| from prototorch.core.distances import squared_euclidean_distance | ||||
| from prototorch.core.losses import NeuralGasEnergy | ||||
|  | ||||
| from .abstract import NonGradientMixin, UnsupervisedPrototypeModel | ||||
| from .abstract import UnsupervisedPrototypeModel | ||||
| from .callbacks import GNGCallback | ||||
| from .extras import ConnectionTopology | ||||
| from .mixins import NonGradientMixin | ||||
|  | ||||
|  | ||||
| class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): | ||||
|   | ||||