diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py
index 7e0b3f4..f6886c6 100644
--- a/prototorch/models/abstract.py
+++ b/prototorch/models/abstract.py
@@ -1,10 +1,17 @@
 """Abstract classes to be inherited by prototorch models."""
 
+import logging
+
 import pytorch_lightning as pl
 import torch
+import torch.nn.functional as F
 import torchmetrics
 from prototorch.core.competitions import WTAC
-from prototorch.core.components import Components, LabeledComponents
+from prototorch.core.components import (
+    AbstractComponents,
+    Components,
+    LabeledComponents,
+)
 from prototorch.core.distances import euclidean_distance
 from prototorch.core.initializers import (
     LabelsInitializer,
@@ -32,7 +39,7 @@ class ProtoTorchBolt(pl.LightningModule):
         self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
 
     def configure_optimizers(self):
-        optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
+        optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
         if self.lr_scheduler is not None:
             scheduler = self.lr_scheduler(optimizer,
                                           **self.lr_scheduler_kwargs)
@@ -45,7 +52,10 @@ class ProtoTorchBolt(pl.LightningModule):
         return optimizer
 
     def reconfigure_optimizers(self):
-        self.trainer.strategy.setup_optimizers(self.trainer)
+        if self.trainer:
+            self.trainer.strategy.setup_optimizers(self.trainer)
+        else:
+            logging.warning("No trainer to reconfigure optimizers!")
 
     def __repr__(self):
         surep = super().__repr__()
@@ -55,6 +65,7 @@ class ProtoTorchBolt(pl.LightningModule):
 
 
 class PrototypeModel(ProtoTorchBolt):
+    proto_layer: AbstractComponents
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
@@ -77,16 +88,17 @@ class PrototypeModel(ProtoTorchBolt):
 
     def add_prototypes(self, *args, **kwargs):
         self.proto_layer.add_components(*args, **kwargs)
-        self.hparams.distribution = self.proto_layer.distribution
+        self.hparams["distribution"] = self.proto_layer.distribution
         self.reconfigure_optimizers()
 
     def remove_prototypes(self, indices):
         self.proto_layer.remove_components(indices)
-        self.hparams.distribution = self.proto_layer.distribution
+        self.hparams["distribution"] = self.proto_layer.distribution
         self.reconfigure_optimizers()
 
 
 class UnsupervisedPrototypeModel(PrototypeModel):
+    proto_layer: Components
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
@@ -95,7 +107,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
         prototypes_initializer = kwargs.get("prototypes_initializer", None)
         if prototypes_initializer is not None:
             self.proto_layer = Components(
-                self.hparams.num_prototypes,
+                self.hparams["num_prototypes"],
                 initializer=prototypes_initializer,
             )
 
@@ -110,6 +122,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
 
 
 class SupervisedPrototypeModel(PrototypeModel):
+    proto_layer: LabeledComponents
 
     def __init__(self, hparams, skip_proto_layer=False, **kwargs):
         super().__init__(hparams, **kwargs)
@@ -129,13 +142,13 @@ class SupervisedPrototypeModel(PrototypeModel):
                 labels_initializer=labels_initializer,
             )
             proto_shape = self.proto_layer.components.shape[1:]
-            self.hparams.initialized_proto_shape = proto_shape
+            self.hparams["initialized_proto_shape"] = proto_shape
         else:
             # when restoring a checkpointed model
             self.proto_layer = LabeledComponents(
                 distribution=distribution,
                 components_initializer=ZerosCompInitializer(
-                    self.hparams.initialized_proto_shape),
+                    self.hparams["initialized_proto_shape"]),
             )
         self.competition_layer = WTAC()
 
@@ -156,7 +169,7 @@ class SupervisedPrototypeModel(PrototypeModel):
         distances = self.compute_distances(x)
         _, plabels = self.proto_layer()
         winning = stratified_min_pooling(distances, plabels)
-        y_pred = torch.nn.functional.softmin(winning, dim=1)
+        y_pred = F.softmin(winning, dim=1)
         return y_pred
 
     def predict_from_distances(self, distances):
@@ -209,8 +222,10 @@ class NonGradientMixin(ProtoTorchMixin):
 
 
 class ImagePrototypesMixin(ProtoTorchMixin):
     """Mixin for models with image prototypes."""
+    proto_layer: Components
+    components: torch.Tensor
 
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, outputs, batch, batch_idx):
         """Constrain the components to the range [0, 1] by clamping after updates."""
         self.proto_layer.components.data.clamp_(0.0, 1.0)
diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py
index 0986f61..b58ccc1 100644
--- a/prototorch/models/callbacks.py
+++ b/prototorch/models/callbacks.py
@@ -1,25 +1,30 @@
 """Lightning Callbacks."""
 
 import logging
+from typing import TYPE_CHECKING
 
 import pytorch_lightning as pl
 import torch
-from prototorch.core.components import Components
 from prototorch.core.initializers import LiteralCompInitializer
 
 from .extras import ConnectionTopology
 
+if TYPE_CHECKING:
+    from prototorch.models import GLVQ, GrowingNeuralGas
+
 
 class PruneLoserPrototypes(pl.Callback):
 
-    def __init__(self,
-                 threshold=0.01,
-                 idle_epochs=10,
-                 prune_quota_per_epoch=-1,
-                 frequency=1,
-                 replace=False,
-                 prototypes_initializer=None,
-                 verbose=False):
+    def __init__(
+        self,
+        threshold=0.01,
+        idle_epochs=10,
+        prune_quota_per_epoch=-1,
+        frequency=1,
+        replace=False,
+        prototypes_initializer=None,
+        verbose=False,
+    ):
         self.threshold = threshold  # minimum win ratio
         self.idle_epochs = idle_epochs  # epochs to wait before pruning
         self.prune_quota_per_epoch = prune_quota_per_epoch
@@ -28,7 +33,7 @@ class PruneLoserPrototypes(pl.Callback):
         self.verbose = verbose
         self.prototypes_initializer = prototypes_initializer
 
-    def on_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
         if (trainer.current_epoch + 1) < self.idle_epochs:
             return None
         if (trainer.current_epoch + 1) % self.frequency:
@@ -43,27 +48,29 @@ class PruneLoserPrototypes(pl.Callback):
             prune_labels = prune_labels[:self.prune_quota_per_epoch]
 
         if len(to_prune) > 0:
-            if self.verbose:
-                print(f"\nPrototype win ratios: {ratios}")
-                print(f"Pruning prototypes at: {to_prune}")
-                print(f"Corresponding labels are: {prune_labels.tolist()}")
+            logging.debug(f"\nPrototype win ratios: {ratios}")
+            logging.debug(f"Pruning prototypes at: {to_prune}")
+            logging.debug(f"Corresponding labels are: {prune_labels.tolist()}")
+
             cur_num_protos = pl_module.num_prototypes
             pl_module.remove_prototypes(indices=to_prune)
+
            if self.replace:
                 labels, counts = torch.unique(prune_labels,
                                               sorted=True,
                                               return_counts=True)
                 distribution = dict(zip(labels.tolist(), counts.tolist()))
-                if self.verbose:
-                    print(f"Re-adding pruned prototypes...")
-                    print(f"distribution={distribution}")
+
+                logging.info(f"Re-adding pruned prototypes...")
+                logging.debug(f"distribution={distribution}")
+
                 pl_module.add_prototypes(
                     distribution=distribution,
                     components_initializer=self.prototypes_initializer)
             new_num_protos = pl_module.num_prototypes
-            if self.verbose:
-                print(f"`num_prototypes` changed from {cur_num_protos} "
-                      f"to {new_num_protos}.")
+
+            logging.info(f"`num_prototypes` changed from {cur_num_protos} "
+                         f"to {new_num_protos}.")
             return True
 
@@ -74,11 +81,11 @@ class PrototypeConvergence(pl.Callback):
         self.idle_epochs = idle_epochs  # epochs to wait
         self.verbose = verbose
 
-    def on_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module):
         if (trainer.current_epoch + 1) < self.idle_epochs:
             return None
-        if self.verbose:
-            print("Stopping...")
+
+        logging.info("Stopping...")
         # TODO
         return True
 
@@ -96,12 +103,16 @@ class GNGCallback(pl.Callback):
         self.reduction = reduction
         self.freq = freq
 
-    def on_epoch_end(self, trainer: pl.Trainer, pl_module):
+    def on_train_epoch_end(
+        self,
+        trainer: pl.Trainer,
+        pl_module: "GrowingNeuralGas",
+    ):
         if (trainer.current_epoch + 1) % self.freq == 0:
             # Get information
             errors = pl_module.errors
             topology: ConnectionTopology = pl_module.topology_layer
-            components: Components = pl_module.proto_layer.components
+            components = pl_module.proto_layer.components
 
             # Insertion point
             worst = torch.argmax(errors)
@@ -121,8 +132,9 @@ class GNGCallback(pl.Callback):
 
             # Add component
             pl_module.proto_layer.add_components(
-                None,
-                initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
+                1,
+                initializer=LiteralCompInitializer(new_component.unsqueeze(0)),
+            )
 
             # Adjust Topology
             topology.add_prototype()
diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
index 29f4c3d..81ddaef 100644
--- a/prototorch/models/glvq.py
+++ b/prototorch/models/glvq.py
@@ -34,9 +34,9 @@ class GLVQ(SupervisedPrototypeModel):
 
         # Loss
         self.loss = GLVQLoss(
-            margin=self.hparams.margin,
-            transfer_fn=self.hparams.transfer_fn,
-            beta=self.hparams.transfer_beta,
+            margin=self.hparams["margin"],
+            transfer_fn=self.hparams["transfer_fn"],
+            beta=self.hparams["transfer_beta"],
         )
 
     # def on_save_checkpoint(self, checkpoint):
@@ -48,7 +48,7 @@ class GLVQ(SupervisedPrototypeModel):
             "prototype_win_ratios",
             torch.zeros(self.num_prototypes, device=self.device))
 
-    def on_epoch_start(self):
+    def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()
 
     def log_prototype_win_ratios(self, distances):
@@ -125,11 +125,11 @@ class SiameseGLVQ(GLVQ):
 
     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
-                                   lr=self.hparams.proto_lr)
+                                   lr=self.hparams["proto_lr"])
         # Only add a backbone optimizer if backbone has trainable parameters
         bb_params = list(self.backbone.parameters())
         if (bb_params):
-            bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
+            bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
             optimizers = [proto_opt, bb_opt]
         else:
             optimizers = [proto_opt]
@@ -199,12 +199,13 @@ class GRLVQ(SiameseGLVQ):
     TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
 
     """
+    _relevances: torch.Tensor
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
 
         # Additional parameters
-        relevances = torch.ones(self.hparams.input_dim, device=self.device)
+        relevances = torch.ones(self.hparams["input_dim"], device=self.device)
         self.register_parameter("_relevances", Parameter(relevances))
 
         # Override the backbone
@@ -233,8 +234,8 @@ class SiameseGMLVQ(SiameseGLVQ):
         omega_initializer = kwargs.get("omega_initializer",
                                        EyeLinearTransformInitializer())
         self.backbone = LinearTransform(
-            self.hparams.input_dim,
-            self.hparams.latent_dim,
+            self.hparams["input_dim"],
+            self.hparams["latent_dim"],
             initializer=omega_initializer,
         )
 
@@ -244,7 +245,7 @@ class SiameseGMLVQ(SiameseGLVQ):
 
     @property
     def lambda_matrix(self):
-        omega = self.backbone.weight  # (input_dim, latent_dim)
+        omega = self.backbone.weights  # (input_dim, latent_dim)
         lam = omega @ omega.T
         return lam.detach().cpu()
 
@@ -257,6 +258,9 @@ class GMLVQ(GLVQ):
 
     """
 
+    # Parameters
+    _omega: torch.Tensor
+
     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", omega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
@@ -264,8 +268,8 @@ class GMLVQ(GLVQ):
         # Additional parameters
         omega_initializer = kwargs.get("omega_initializer",
                                        EyeLinearTransformInitializer())
-        omega = omega_initializer.generate(self.hparams.input_dim,
-                                           self.hparams.latent_dim)
+        omega = omega_initializer.generate(self.hparams["input_dim"],
+                                           self.hparams["latent_dim"])
         self.register_parameter("_omega", Parameter(omega))
         self.backbone = LambdaLayer(lambda x: x @ self._omega,
                                     name="omega matrix")
@@ -299,8 +303,8 @@ class LGMLVQ(GMLVQ):
         # Re-register `_omega` to override the one from the super class.
         omega = torch.randn(
             self.num_prototypes,
-            self.hparams.input_dim,
-            self.hparams.latent_dim,
+            self.hparams["input_dim"],
+            self.hparams["latent_dim"],
             device=self.device,
         )
         self.register_parameter("_omega", Parameter(omega))
@@ -316,23 +320,27 @@ class GTLVQ(LGMLVQ):
 
         omega_initializer = kwargs.get("omega_initializer")
         if omega_initializer is not None:
-            subspace = omega_initializer.generate(self.hparams.input_dim,
-                                                  self.hparams.latent_dim)
-            omega = torch.repeat_interleave(subspace.unsqueeze(0),
-                                            self.num_prototypes,
-                                            dim=0)
+            subspace = omega_initializer.generate(
+                self.hparams["input_dim"],
+                self.hparams["latent_dim"],
+            )
+            omega = torch.repeat_interleave(
+                subspace.unsqueeze(0),
+                self.num_prototypes,
+                dim=0,
+            )
         else:
             omega = torch.rand(
                 self.num_prototypes,
-                self.hparams.input_dim,
-                self.hparams.latent_dim,
+                self.hparams["input_dim"],
+                self.hparams["latent_dim"],
                 device=self.device,
             )
 
         # Re-register `_omega` to override the one from the super class.
         self.register_parameter("_omega", Parameter(omega))
 
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, outputs, batch, batch_idx):
         with torch.no_grad():
             self._omega.copy_(orthogonalization(self._omega))
@@ -389,7 +397,7 @@ class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
 
     """
 
-    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, outputs, batch, batch_idx):
         """Constrain the components to the range [0, 1] by clamping after updates."""
         self.proto_layer.components.data.clamp_(0.0, 1.0)
         with torch.no_grad():
diff --git a/prototorch/models/knn.py b/prototorch/models/knn.py
index 081ea55..d277206 100644
--- a/prototorch/models/knn.py
+++ b/prototorch/models/knn.py
@@ -37,10 +37,7 @@ class KNN(SupervisedPrototypeModel):
 
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
         return 1  # skip training step
 
-    def on_train_batch_start(self,
-                             train_batch,
-                             batch_idx,
-                             dataloader_idx=None):
+    def on_train_batch_start(self, train_batch, batch_idx):
         warnings.warn("k-NN has no training, skipping!")
         return -1
diff --git a/prototorch/models/lvq.py b/prototorch/models/lvq.py
index 0dd02ec..aa893a2 100644
--- a/prototorch/models/lvq.py
+++ b/prototorch/models/lvq.py
@@ -1,5 +1,7 @@
 """LVQ models that are optimized using non-gradient methods."""
 
+import logging
+
 from prototorch.core.losses import _get_dp_dm
 from prototorch.nn.activations import get_activation
 from prototorch.nn.wrappers import LambdaLayer
@@ -30,8 +32,8 @@ class LVQ1(NonGradientMixin, GLVQ):
             self.proto_layer.load_state_dict({"_components": updated_protos},
                                              strict=False)
 
-        print(f"dis={dis}")
-        print(f"y={y}")
+        logging.debug(f"dis={dis}")
+        logging.debug(f"y={y}")
 
         # Logging
         self.log_acc(dis, y, tag="train_acc")
@@ -74,8 +76,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
 
     """
 
-    def __init__(self, hparams, verbose=True, **kwargs):
-        self.verbose = verbose
+    def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
 
         self.transfer_layer = LambdaLayer(
@@ -116,8 +117,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
                 _protos[i] = xk
                 _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
                 if _lower_bound > lower_bound:
-                    if self.verbose:
-                        print(f"Updating prototype {i} to data {k}...")
+                    logging.debug(f"Updating prototype {i} to data {k}...")
                     self.proto_layer.load_state_dict({"_components": _protos},
                                                      strict=False)
                     break
diff --git a/prototorch/models/probabilistic.py b/prototorch/models/probabilistic.py
index ab3f141..79da5d9 100644
--- a/prototorch/models/probabilistic.py
+++ b/prototorch/models/probabilistic.py
@@ -37,17 +37,24 @@ class ProbabilisticLVQ(GLVQ):
 
     def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.conditional_distribution = None
         self.rejection_confidence = rejection_confidence
+        self._conditional_distribution = None
 
     def forward(self, x):
         distances = self.compute_distances(x)
+
         conditional = self.conditional_distribution(distances)
         prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
                                                         device=self.device)
         posterior = conditional * prior
+
         plabels = self.proto_layer._labels
-        y_pred = stratified_sum_pooling(posterior, plabels)
+        if isinstance(plabels, torch.LongTensor) or isinstance(
+                plabels, torch.cuda.LongTensor):  # type: ignore
+            y_pred = stratified_sum_pooling(posterior, plabels)  # type: ignore
+        else:
+            raise ValueError("Labels must be LongTensor.")
+
         return y_pred
 
     def predict(self, x):
@@ -64,6 +71,12 @@ class ProbabilisticLVQ(GLVQ):
         loss = batch_loss.sum()
         return loss
 
+    def conditional_distribution(self, distances):
+        """Conditional distribution of distances."""
+        if self._conditional_distribution is None:
+            raise ValueError("Conditional distribution is not set.")
+        return self._conditional_distribution(distances)
+
 
 class SLVQ(ProbabilisticLVQ):
     """Soft Learning Vector Quantization."""
@@ -75,7 +88,7 @@ class SLVQ(ProbabilisticLVQ):
 
         self.hparams.setdefault("variance", 1.0)
         variance = self.hparams.get("variance")
-        self.conditional_distribution = GaussianPrior(variance)
+        self._conditional_distribution = GaussianPrior(variance)
 
         self.loss = LossLayer(nllr_loss)
 
@@ -89,7 +102,7 @@ class RSLVQ(ProbabilisticLVQ):
 
         self.hparams.setdefault("variance", 1.0)
         variance = self.hparams.get("variance")
-        self.conditional_distribution = GaussianPrior(variance)
+        self._conditional_distribution = GaussianPrior(variance)
 
         self.loss = LossLayer(rslvq_loss)
 
diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py
index 0a477c1..8833de2 100644
--- a/prototorch/models/unsupervised.py
+++ b/prototorch/models/unsupervised.py
@@ -17,6 +17,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
     TODO Allow non-2D grids
 
     """
+    _grid: torch.Tensor
 
     def __init__(self, hparams, **kwargs):
         h, w = hparams.get("shape")
@@ -92,10 +93,10 @@ class NeuralGas(UnsupervisedPrototypeModel):
         self.hparams.setdefault("age_limit", 10)
         self.hparams.setdefault("lm", 1)
 
-        self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
+        self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
         self.topology_layer = ConnectionTopology(
-            agelimit=self.hparams.age_limit,
-            num_prototypes=self.hparams.num_prototypes,
+            agelimit=self.hparams["age_limit"],
+            num_prototypes=self.hparams["num_prototypes"],
         )
 
     def training_step(self, train_batch, batch_idx):
@@ -108,12 +109,9 @@ class NeuralGas(UnsupervisedPrototypeModel):
         self.log("loss", loss)
         return loss
 
-    # def training_epoch_end(self, training_step_outputs):
-    #     print(f"{self.trainer.lr_schedulers}")
-    #     print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
-
 
 class GrowingNeuralGas(NeuralGas):
+    errors: torch.Tensor
 
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
@@ -123,7 +121,10 @@ class GrowingNeuralGas(NeuralGas):
         self.hparams.setdefault("insert_reduction", 0.1)
         self.hparams.setdefault("insert_freq", 10)
 
-        errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
+        errors = torch.zeros(
+            self.hparams["num_prototypes"],
+            device=self.device,
+        )
         self.register_buffer("errors", errors)
 
     def training_step(self, train_batch, _batch_idx):
@@ -138,7 +139,7 @@ class GrowingNeuralGas(NeuralGas):
         dp = d * mask
 
         self.errors += torch.sum(dp * dp)
-        self.errors *= self.hparams.step_reduction
+        self.errors *= self.hparams["step_reduction"]
 
         self.topology_layer(d)
         self.log("loss", loss)
@@ -147,7 +148,7 @@ class GrowingNeuralGas(NeuralGas):
     def configure_callbacks(self):
         return [
             GNGCallback(
-                reduction=self.hparams.insert_reduction,
-                freq=self.hparams.insert_freq,
+                reduction=self.hparams["insert_reduction"],
+                freq=self.hparams["insert_freq"],
             )
         ]
diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py
index 6c16b4d..5e4dd95 100644
--- a/prototorch/models/vis.py
+++ b/prototorch/models/vis.py
@@ -1,5 +1,8 @@
 """Visualization Callbacks."""
 
+import warnings
+from typing import Sized
+
 import numpy as np
 import pytorch_lightning as pl
 import torch
@@ -7,6 +10,7 @@ import torchvision
 from matplotlib import pyplot as plt
 from prototorch.utils.colors import get_colors, get_legend_handles
 from prototorch.utils.utils import mesh2d
+from pytorch_lightning.loggers import TensorBoardLogger
 from torch.utils.data import DataLoader, Dataset
 
 
@@ -33,8 +37,13 @@ class Vis2DAbstract(pl.Callback):
 
         if data:
             if isinstance(data, Dataset):
-                x, y = next(iter(DataLoader(data, batch_size=len(data))))
-            elif isinstance(data, torch.utils.data.DataLoader):
+                if isinstance(data, Sized):
+                    x, y = next(iter(DataLoader(data, batch_size=len(data))))
+                else:
+                    # TODO: Add support for non-sized datasets
+                    raise NotImplementedError(
+                        "Data must be a dataset with a __len__ method.")
+            elif isinstance(data, DataLoader):
                 x = torch.tensor([])
                 y = torch.tensor([])
                 for x_b, y_b in data:
@@ -122,7 +131,7 @@ class Vis2DAbstract(pl.Callback):
         else:
             plt.show(block=self.block)
 
-    def on_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module):
         if not self.precheck(trainer):
             return True
         self.visualize(pl_module)
@@ -131,6 +140,9 @@
     def on_train_end(self, trainer, pl_module):
         plt.close()
 
+    def visualize(self, pl_module):
+        raise NotImplementedError
+
 
 class VisGLVQ2D(Vis2DAbstract):
 
@@ -291,30 +303,45 @@ class VisImgComp(Vis2DAbstract):
         self.add_embedding = add_embedding
         self.embedding_data = embedding_data
 
-    def on_train_start(self, trainer, pl_module):
-        tb = pl_module.logger.experiment
-        if self.add_embedding:
-            ind = np.random.choice(len(self.x_train),
-                                   size=self.embedding_data,
-                                   replace=False)
-            data = self.x_train[ind]
-            tb.add_embedding(data.view(len(ind), -1),
-                             label_img=data,
-                             global_step=None,
-                             tag="Data Embedding",
-                             metadata=self.y_train[ind],
-                             metadata_header=None)
+    def on_train_start(self, _, pl_module):
+        if isinstance(pl_module.logger, TensorBoardLogger):
+            tb = pl_module.logger.experiment
 
-        if self.random_data:
-            ind = np.random.choice(len(self.x_train),
-                                   size=self.random_data,
-                                   replace=False)
-            data = self.x_train[ind]
-            grid = torchvision.utils.make_grid(data, nrow=self.num_columns)
-            tb.add_image(tag="Data",
-                         img_tensor=grid,
-                         global_step=None,
-                         dataformats=self.dataformats)
+            # Add embedding
+            if self.add_embedding:
+                if self.x_train is not None and self.y_train is not None:
+                    ind = np.random.choice(len(self.x_train),
+                                           size=self.embedding_data,
+                                           replace=False)
+                    data = self.x_train[ind]
+                    tb.add_embedding(data.view(len(ind), -1),
+                                     label_img=data,
+                                     global_step=None,
+                                     tag="Data Embedding",
+                                     metadata=self.y_train[ind],
+                                     metadata_header=None)
+                else:
+                    raise ValueError("No data for add embedding flag")
+
+            # Random Data
+            if self.random_data:
+                if self.x_train is not None:
+                    ind = np.random.choice(len(self.x_train),
+                                           size=self.random_data,
+                                           replace=False)
+                    data = self.x_train[ind]
+                    grid = torchvision.utils.make_grid(data,
+                                                       nrow=self.num_columns)
+                    tb.add_image(tag="Data",
+                                 img_tensor=grid,
+                                 global_step=None,
+                                 dataformats=self.dataformats)
+                else:
+                    raise ValueError("No data for random data flag")
+
+        else:
+            warnings.warn(
+                f"TensorBoardLogger is required, got {type(pl_module.logger)}")
 
     def add_to_tensorboard(self, trainer, pl_module):
         tb = pl_module.logger.experiment