chore: minor updates and version updates

Alexander Engelsberger 2022-05-17 12:00:52 +02:00
parent bccef8bef0
commit c00513ae0d
8 changed files with 184 additions and 111 deletions

View File

@@ -1,10 +1,17 @@
"""Abstract classes to be inherited by prototorch models."""
import logging
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from prototorch.core.competitions import WTAC
from prototorch.core.components import Components, LabeledComponents
from prototorch.core.components import (
AbstractComponents,
Components,
LabeledComponents,
)
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
LabelsInitializer,
@@ -32,7 +39,7 @@ class ProtoTorchBolt(pl.LightningModule):
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
if self.lr_scheduler is not None:
scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs)
@@ -45,7 +52,10 @@ class ProtoTorchBolt(pl.LightningModule):
return optimizer
def reconfigure_optimizers(self):
self.trainer.strategy.setup_optimizers(self.trainer)
if self.trainer:
self.trainer.strategy.setup_optimizers(self.trainer)
else:
logging.warning("No trainer to reconfigure optimizers!")
def __repr__(self):
surep = super().__repr__()
@@ -55,6 +65,7 @@ class ProtoTorchBolt(pl.LightningModule):
class PrototypeModel(ProtoTorchBolt):
proto_layer: AbstractComponents
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -77,16 +88,17 @@ class PrototypeModel(ProtoTorchBolt):
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
self.hparams.distribution = self.proto_layer.distribution
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.hparams.distribution = self.proto_layer.distribution
self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel):
proto_layer: Components
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -95,7 +107,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
prototypes_initializer = kwargs.get("prototypes_initializer", None)
if prototypes_initializer is not None:
self.proto_layer = Components(
self.hparams.num_prototypes,
self.hparams["num_prototypes"],
initializer=prototypes_initializer,
)
@@ -110,6 +122,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel):
proto_layer: LabeledComponents
def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs)
@@ -129,13 +142,13 @@ class SupervisedPrototypeModel(PrototypeModel):
labels_initializer=labels_initializer,
)
proto_shape = self.proto_layer.components.shape[1:]
self.hparams.initialized_proto_shape = proto_shape
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.proto_layer = LabeledComponents(
distribution=distribution,
components_initializer=ZerosCompInitializer(
self.hparams.initialized_proto_shape),
self.hparams["initialized_proto_shape"]),
)
self.competition_layer = WTAC()
@@ -156,7 +169,7 @@ class SupervisedPrototypeModel(PrototypeModel):
distances = self.compute_distances(x)
_, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning, dim=1)
y_pred = F.softmin(winning, dim=1)
return y_pred
def predict_from_distances(self, distances):
@@ -209,8 +222,10 @@ class NonGradientMixin(ProtoTorchMixin):
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
proto_layer: Components
components: torch.Tensor
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
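
Note on the pattern above: the hunks replace attribute-style hyperparameter access (self.hparams.lr) with dict-style access (self.hparams["lr"]) and guard reconfigure_optimizers against a module that is not yet attached to a trainer. A minimal sketch of both idioms outside ProtoTorch follows; the Demo class and the lr key are illustrative, not part of the commit.

import logging

import pytorch_lightning as pl
import torch


class Demo(pl.LightningModule):
    """Illustrative module, not part of the commit."""

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.layer = torch.nn.Linear(2, 2)

    def configure_optimizers(self):
        # Dict-style access works whether the hyperparameters arrived as a
        # plain dict or as a Namespace, so it is the safer spelling.
        return torch.optim.Adam(self.parameters(), lr=self.hparams["lr"])

    def reconfigure_optimizers(self):
        # self.trainer is unset until the module is attached to a Trainer,
        # hence the guard introduced in this commit.
        if self.trainer:
            self.trainer.strategy.setup_optimizers(self.trainer)
        else:
            logging.warning("No trainer to reconfigure optimizers!")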

View File

@@ -1,25 +1,30 @@
"""Lightning Callbacks."""
import logging
from typing import TYPE_CHECKING
import pytorch_lightning as pl
import torch
from prototorch.core.components import Components
from prototorch.core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology
if TYPE_CHECKING:
from prototorch.models import GLVQ, GrowingNeuralGas
class PruneLoserPrototypes(pl.Callback):
def __init__(self,
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=-1,
frequency=1,
replace=False,
prototypes_initializer=None,
verbose=False):
def __init__(
self,
threshold=0.01,
idle_epochs=10,
prune_quota_per_epoch=-1,
frequency=1,
replace=False,
prototypes_initializer=None,
verbose=False,
):
self.threshold = threshold # minimum win ratio
self.idle_epochs = idle_epochs # epochs to wait before pruning
self.prune_quota_per_epoch = prune_quota_per_epoch
@@ -28,7 +33,7 @@ class PruneLoserPrototypes(pl.Callback):
self.verbose = verbose
self.prototypes_initializer = prototypes_initializer
def on_epoch_end(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
if (trainer.current_epoch + 1) % self.frequency:
@@ -43,27 +48,29 @@ class PruneLoserPrototypes(pl.Callback):
prune_labels = prune_labels[:self.prune_quota_per_epoch]
if len(to_prune) > 0:
if self.verbose:
print(f"\nPrototype win ratios: {ratios}")
print(f"Pruning prototypes at: {to_prune}")
print(f"Corresponding labels are: {prune_labels.tolist()}")
logging.debug(f"\nPrototype win ratios: {ratios}")
logging.debug(f"Pruning prototypes at: {to_prune}")
logging.debug(f"Corresponding labels are: {prune_labels.tolist()}")
cur_num_protos = pl_module.num_prototypes
pl_module.remove_prototypes(indices=to_prune)
if self.replace:
labels, counts = torch.unique(prune_labels,
sorted=True,
return_counts=True)
distribution = dict(zip(labels.tolist(), counts.tolist()))
if self.verbose:
print(f"Re-adding pruned prototypes...")
print(f"distribution={distribution}")
logging.info(f"Re-adding pruned prototypes...")
logging.debug(f"distribution={distribution}")
pl_module.add_prototypes(
distribution=distribution,
components_initializer=self.prototypes_initializer)
new_num_protos = pl_module.num_prototypes
if self.verbose:
print(f"`num_prototypes` changed from {cur_num_protos} "
f"to {new_num_protos}.")
logging.info(f"`num_prototypes` changed from {cur_num_protos} "
f"to {new_num_protos}.")
return True
@@ -74,11 +81,11 @@ class PrototypeConvergence(pl.Callback):
self.idle_epochs = idle_epochs # epochs to wait
self.verbose = verbose
def on_epoch_end(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs:
return None
if self.verbose:
print("Stopping...")
logging.info("Stopping...")
# TODO
return True
@@ -96,12 +103,16 @@ class GNGCallback(pl.Callback):
self.reduction = reduction
self.freq = freq
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
def on_train_epoch_end(
self,
trainer: pl.Trainer,
pl_module: "GrowingNeuralGas",
):
if (trainer.current_epoch + 1) % self.freq == 0:
# Get information
errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer
components: Components = pl_module.proto_layer.components
components = pl_module.proto_layer.components
# Insertion point
worst = torch.argmax(errors)
@@ -121,8 +132,9 @@ class GNGCallback(pl.Callback):
# Add component
pl_module.proto_layer.add_components(
None,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)))
1,
initializer=LiteralCompInitializer(new_component.unsqueeze(0)),
)
# Adjust Topology
topology.add_prototype()
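
Note on the callback changes above: the deprecated generic on_epoch_end hook is replaced by the train-specific on_train_epoch_end, print calls become logging calls, and the model types used only in annotations are imported under TYPE_CHECKING so they do not trigger a circular import at runtime. A minimal sketch of the same combination; DemoCallback is an illustrative name.

import logging
from typing import TYPE_CHECKING

import pytorch_lightning as pl

if TYPE_CHECKING:
    # Imported for static type checking only; nothing is imported at runtime.
    from prototorch.models import GLVQ


class DemoCallback(pl.Callback):

    def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
        # Fires once per training epoch; the generic on_epoch_end hook
        # is deprecated in newer PyTorch Lightning releases.
        logging.debug(f"Finished epoch {trainer.current_epoch}")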

View File

@@ -34,9 +34,9 @@ class GLVQ(SupervisedPrototypeModel):
# Loss
self.loss = GLVQLoss(
margin=self.hparams.margin,
transfer_fn=self.hparams.transfer_fn,
beta=self.hparams.transfer_beta,
margin=self.hparams["margin"],
transfer_fn=self.hparams["transfer_fn"],
beta=self.hparams["transfer_beta"],
)
# def on_save_checkpoint(self, checkpoint):
@@ -48,7 +48,7 @@ class GLVQ(SupervisedPrototypeModel):
"prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device))
def on_epoch_start(self):
def on_train_epoch_start(self):
self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances):
@@ -125,11 +125,11 @@ class SiameseGLVQ(GLVQ):
def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams.proto_lr)
lr=self.hparams["proto_lr"])
# Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt]
else:
optimizers = [proto_opt]
@@ -199,12 +199,13 @@ class GRLVQ(SiameseGLVQ):
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
"""
_relevances: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# Additional parameters
relevances = torch.ones(self.hparams.input_dim, device=self.device)
relevances = torch.ones(self.hparams["input_dim"], device=self.device)
self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone
@@ -233,8 +234,8 @@ class SiameseGMLVQ(SiameseGLVQ):
omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer())
self.backbone = LinearTransform(
self.hparams.input_dim,
self.hparams.latent_dim,
self.hparams["input_dim"],
self.hparams["latent_dim"],
initializer=omega_initializer,
)
@@ -244,7 +245,7 @@ class SiameseGMLVQ(SiameseGLVQ):
@property
def lambda_matrix(self):
omega = self.backbone.weight # (input_dim, latent_dim)
omega = self.backbone.weights # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
@@ -257,6 +258,9 @@ class GMLVQ(GLVQ):
"""
# Parameters
_omega: torch.Tensor
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
@@ -264,8 +268,8 @@ class GMLVQ(GLVQ):
# Additional parameters
omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
omega = omega_initializer.generate(self.hparams["input_dim"],
self.hparams["latent_dim"])
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
@@ -299,8 +303,8 @@ class LGMLVQ(GMLVQ):
# Re-register `_omega` to override the one from the super class.
omega = torch.randn(
self.num_prototypes,
self.hparams.input_dim,
self.hparams.latent_dim,
self.hparams["input_dim"],
self.hparams["latent_dim"],
device=self.device,
)
self.register_parameter("_omega", Parameter(omega))
@@ -316,23 +320,27 @@ class GTLVQ(LGMLVQ):
omega_initializer = kwargs.get("omega_initializer")
if omega_initializer is not None:
subspace = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
omega = torch.repeat_interleave(subspace.unsqueeze(0),
self.num_prototypes,
dim=0)
subspace = omega_initializer.generate(
self.hparams["input_dim"],
self.hparams["latent_dim"],
)
omega = torch.repeat_interleave(
subspace.unsqueeze(0),
self.num_prototypes,
dim=0,
)
else:
omega = torch.rand(
self.num_prototypes,
self.hparams.input_dim,
self.hparams.latent_dim,
self.hparams["input_dim"],
self.hparams["latent_dim"],
device=self.device,
)
# Re-register `_omega` to override the one from the super class.
self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, outputs, batch, batch_idx):
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))
@@ -389,7 +397,7 @@ class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
with torch.no_grad():
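
Note on the GLVQ changes above: the batch-end hooks drop the dataloader_idx parameter, which newer PyTorch Lightning releases no longer pass to on_train_batch_end, and the image variants keep clamping their components into [0, 1] after every update. A rough sketch of that constraint pattern on a plain module; the ClampedPrototypes name is illustrative.

import torch


class ClampedPrototypes(torch.nn.Module):
    """Illustrative module that keeps its prototypes inside [0, 1]."""

    def __init__(self, num_prototypes, dim):
        super().__init__()
        self.components = torch.nn.Parameter(torch.rand(num_prototypes, dim))

    def constrain(self):
        # Called after each optimizer step, e.g. from on_train_batch_end.
        with torch.no_grad():
            self.components.clamp_(0.0, 1.0)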

View File

@@ -37,10 +37,7 @@ class KNN(SupervisedPrototypeModel):
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1 # skip training step
def on_train_batch_start(self,
train_batch,
batch_idx,
dataloader_idx=None):
def on_train_batch_start(self, train_batch, batch_idx):
warnings.warn("k-NN has no training, skipping!")
return -1
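
Note on the k-NN change above: the hook signature is reduced to (train_batch, batch_idx), and returning -1 from on_train_batch_start tells Lightning to skip the remaining batches of the epoch, since k-NN has nothing to optimize. A minimal sketch; DemoKNN is an illustrative name.

import warnings

import pytorch_lightning as pl


class DemoKNN(pl.LightningModule):

    def training_step(self, train_batch, batch_idx):
        return None  # nothing to optimize

    def on_train_batch_start(self, train_batch, batch_idx):
        # Returning -1 makes Lightning skip the rest of the current epoch.
        warnings.warn("k-NN has no training, skipping!")
        return -1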

View File

@@ -1,5 +1,7 @@
"""LVQ models that are optimized using non-gradient methods."""
import logging
from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer
@@ -30,8 +32,8 @@ class LVQ1(NonGradientMixin, GLVQ):
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
print(f"dis={dis}")
print(f"y={y}")
logging.debug(f"dis={dis}")
logging.debug(f"y={y}")
# Logging
self.log_acc(dis, y, tag="train_acc")
@@ -74,8 +76,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
"""
def __init__(self, hparams, verbose=True, **kwargs):
self.verbose = verbose
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer(
@@ -116,8 +117,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
_protos[i] = xk
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound:
if self.verbose:
print(f"Updating prototype {i} to data {k}...")
logging.debug(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict({"_components": _protos},
strict=False)
break
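
Note on the LVQ changes above: bare print calls become logging.debug calls, which stay silent unless the caller raises the log level. A minimal sketch of how a training script would surface them; not part of the commit.

import logging

# Opt in to debug output from the library in a training script.
logging.basicConfig(level=logging.DEBUG)
logging.debug("dis=...")  # now visible on stderr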

View File

@@ -37,17 +37,24 @@ class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs)
self.conditional_distribution = None
self.rejection_confidence = rejection_confidence
self._conditional_distribution = None
def forward(self, x):
distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
plabels = self.proto_layer._labels
y_pred = stratified_sum_pooling(posterior, plabels)
if isinstance(plabels, torch.LongTensor) or isinstance(
plabels, torch.cuda.LongTensor): # type: ignore
y_pred = stratified_sum_pooling(posterior, plabels) # type: ignore
else:
raise ValueError("Labels must be LongTensor.")
return y_pred
def predict(self, x):
@@ -64,6 +71,12 @@ class ProbabilisticLVQ(GLVQ):
loss = batch_loss.sum()
return loss
def conditional_distribution(self, distances):
"""Conditional distribution of distances."""
if self._conditional_distribution is None:
raise ValueError("Conditional distribution is not set.")
return self._conditional_distribution(distances)
class SLVQ(ProbabilisticLVQ):
"""Soft Learning Vector Quantization."""
@@ -75,7 +88,7 @@ class SLVQ(ProbabilisticLVQ):
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance)
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(nllr_loss)
@@ -89,7 +102,7 @@ class RSLVQ(ProbabilisticLVQ):
self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance)
self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(rslvq_loss)
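
Note on the probabilistic changes above: the prior is now stored in a private _conditional_distribution attribute and accessed through a method that fails loudly when no distribution has been set, instead of raising an opaque error from calling None. The same guard in isolation; DemoProbabilistic is an illustrative name.

from typing import Callable, Optional

import torch


class DemoProbabilistic:

    def __init__(self):
        # Set by subclasses, e.g. to a Gaussian prior over distances.
        self._conditional_distribution: Optional[Callable] = None

    def conditional_distribution(self, distances: torch.Tensor) -> torch.Tensor:
        if self._conditional_distribution is None:
            raise ValueError("Conditional distribution is not set.")
        return self._conditional_distribution(distances)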

View File

@@ -17,6 +17,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
TODO Allow non-2D grids
"""
_grid: torch.Tensor
def __init__(self, hparams, **kwargs):
h, w = hparams.get("shape")
@@ -92,10 +93,10 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.hparams.setdefault("age_limit", 10)
self.hparams.setdefault("lm", 1)
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
self.topology_layer = ConnectionTopology(
agelimit=self.hparams.age_limit,
num_prototypes=self.hparams.num_prototypes,
agelimit=self.hparams["age_limit"],
num_prototypes=self.hparams["num_prototypes"],
)
def training_step(self, train_batch, batch_idx):
@@ -108,12 +109,9 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.log("loss", loss)
return loss
# def training_epoch_end(self, training_step_outputs):
# print(f"{self.trainer.lr_schedulers}")
# print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
class GrowingNeuralGas(NeuralGas):
errors: torch.Tensor
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
@@ -123,7 +121,10 @@ class GrowingNeuralGas(NeuralGas):
self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10)
errors = torch.zeros(self.hparams.num_prototypes, device=self.device)
errors = torch.zeros(
self.hparams["num_prototypes"],
device=self.device,
)
self.register_buffer("errors", errors)
def training_step(self, train_batch, _batch_idx):
@@ -138,7 +139,7 @@ class GrowingNeuralGas(NeuralGas):
dp = d * mask
self.errors += torch.sum(dp * dp)
self.errors *= self.hparams.step_reduction
self.errors *= self.hparams["step_reduction"]
self.topology_layer(d)
self.log("loss", loss)
@@ -147,7 +148,7 @@ class GrowingNeuralGas(NeuralGas):
def configure_callbacks(self):
return [
GNGCallback(
reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq,
reduction=self.hparams["insert_reduction"],
freq=self.hparams["insert_freq"],
)
]
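
Note on the unsupervised changes above: GrowingNeuralGas keeps its accumulated errors in a registered buffer, so the tensor follows the module across devices and is saved in checkpoints without being optimized; the hunks otherwise switch to dict-style hparams access. A minimal sketch of the buffer pattern; DemoGNG is an illustrative name.

import torch


class DemoGNG(torch.nn.Module):
    errors: torch.Tensor

    def __init__(self, num_prototypes: int):
        super().__init__()
        # Buffers appear in state_dict and move with .to(device),
        # but are not returned by .parameters().
        self.register_buffer("errors", torch.zeros(num_prototypes))

    def decay(self, step_reduction: float):
        self.errors *= step_reduction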

View File

@@ -1,5 +1,8 @@
"""Visualization Callbacks."""
import warnings
from typing import Sized
import numpy as np
import pytorch_lightning as pl
import torch
@@ -7,6 +10,7 @@ import torchvision
from matplotlib import pyplot as plt
from prototorch.utils.colors import get_colors, get_legend_handles
from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset
@@ -33,8 +37,13 @@ class Vis2DAbstract(pl.Callback):
if data:
if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader):
if isinstance(data, Sized):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
else:
# TODO: Add support for non-sized datasets
raise NotImplementedError(
"Data must be a dataset with a __len__ method.")
elif isinstance(data, DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
@@ -122,7 +131,7 @@ class Vis2DAbstract(pl.Callback):
else:
plt.show(block=self.block)
def on_epoch_end(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
self.visualize(pl_module)
@@ -131,6 +140,9 @@ class Vis2DAbstract(pl.Callback):
def on_train_end(self, trainer, pl_module):
plt.close()
def visualize(self, pl_module):
raise NotImplementedError
class VisGLVQ2D(Vis2DAbstract):
@@ -291,30 +303,45 @@ class VisImgComp(Vis2DAbstract):
self.add_embedding = add_embedding
self.embedding_data = embedding_data
def on_train_start(self, trainer, pl_module):
tb = pl_module.logger.experiment
if self.add_embedding:
ind = np.random.choice(len(self.x_train),
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
tb.add_embedding(data.view(len(ind), -1),
label_img=data,
global_step=None,
tag="Data Embedding",
metadata=self.y_train[ind],
metadata_header=None)
def on_train_start(self, _, pl_module):
if isinstance(pl_module.logger, TensorBoardLogger):
tb = pl_module.logger.experiment
if self.random_data:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data = self.x_train[ind]
grid = torchvision.utils.make_grid(data, nrow=self.num_columns)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=None,
dataformats=self.dataformats)
# Add embedding
if self.add_embedding:
if self.x_train is not None and self.y_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
tb.add_embedding(data.view(len(ind), -1),
label_img=data,
global_step=None,
tag="Data Embedding",
metadata=self.y_train[ind],
metadata_header=None)
else:
raise ValueError("No data for add embedding flag")
# Random Data
if self.random_data:
if self.x_train is not None:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data = self.x_train[ind]
grid = torchvision.utils.make_grid(data,
nrow=self.num_columns)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=None,
dataformats=self.dataformats)
else:
raise ValueError("No data for random data flag")
else:
warnings.warn(
f"TensorBoardLogger is required, got {type(pl_module.logger)}")
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
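
Note on the visualization changes above: TensorBoard-specific calls are now made only when the attached logger really is a TensorBoardLogger, with a warning otherwise, and the embedding/random-data branches additionally check that x_train/y_train exist. The logger guard in a standalone callback; DemoVis and the logged tag are illustrative.

import warnings

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


class DemoVis(pl.Callback):

    def on_train_start(self, trainer, pl_module):
        if isinstance(pl_module.logger, TensorBoardLogger):
            tb = pl_module.logger.experiment  # a torch SummaryWriter
            tb.add_text("note", "training started")
        else:
            warnings.warn(
                f"TensorBoardLogger is required, got {type(pl_module.logger)}")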