chore: minor updates and version updates

This commit is contained in:
Alexander Engelsberger 2022-05-17 12:00:52 +02:00
parent bccef8bef0
commit c00513ae0d
No known key found for this signature in database
GPG Key ID: 72E54A9DAE51EB96
8 changed files with 184 additions and 111 deletions

View File

@ -1,10 +1,17 @@
"""Abstract classes to be inherited by prototorch models.""" """Abstract classes to be inherited by prototorch models."""
import logging
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.nn.functional as F
import torchmetrics import torchmetrics
from prototorch.core.competitions import WTAC from prototorch.core.competitions import WTAC
from prototorch.core.components import Components, LabeledComponents from prototorch.core.components import (
AbstractComponents,
Components,
LabeledComponents,
)
from prototorch.core.distances import euclidean_distance from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import ( from prototorch.core.initializers import (
LabelsInitializer, LabelsInitializer,
@ -32,7 +39,7 @@ class ProtoTorchBolt(pl.LightningModule):
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict()) self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
def configure_optimizers(self): def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr) optimizer = self.optimizer(self.parameters(), lr=self.hparams["lr"])
if self.lr_scheduler is not None: if self.lr_scheduler is not None:
scheduler = self.lr_scheduler(optimizer, scheduler = self.lr_scheduler(optimizer,
**self.lr_scheduler_kwargs) **self.lr_scheduler_kwargs)
@ -45,7 +52,10 @@ class ProtoTorchBolt(pl.LightningModule):
return optimizer return optimizer
def reconfigure_optimizers(self): def reconfigure_optimizers(self):
if self.trainer:
self.trainer.strategy.setup_optimizers(self.trainer) self.trainer.strategy.setup_optimizers(self.trainer)
else:
logging.warning("No trainer to reconfigure optimizers!")
def __repr__(self): def __repr__(self):
surep = super().__repr__() surep = super().__repr__()
@ -55,6 +65,7 @@ class ProtoTorchBolt(pl.LightningModule):
class PrototypeModel(ProtoTorchBolt): class PrototypeModel(ProtoTorchBolt):
proto_layer: AbstractComponents
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@ -77,16 +88,17 @@ class PrototypeModel(ProtoTorchBolt):
def add_prototypes(self, *args, **kwargs): def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs) self.proto_layer.add_components(*args, **kwargs)
self.hparams.distribution = self.proto_layer.distribution self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers() self.reconfigure_optimizers()
def remove_prototypes(self, indices): def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices) self.proto_layer.remove_components(indices)
self.hparams.distribution = self.proto_layer.distribution self.hparams["distribution"] = self.proto_layer.distribution
self.reconfigure_optimizers() self.reconfigure_optimizers()
class UnsupervisedPrototypeModel(PrototypeModel): class UnsupervisedPrototypeModel(PrototypeModel):
proto_layer: Components
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@ -95,7 +107,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
prototypes_initializer = kwargs.get("prototypes_initializer", None) prototypes_initializer = kwargs.get("prototypes_initializer", None)
if prototypes_initializer is not None: if prototypes_initializer is not None:
self.proto_layer = Components( self.proto_layer = Components(
self.hparams.num_prototypes, self.hparams["num_prototypes"],
initializer=prototypes_initializer, initializer=prototypes_initializer,
) )
@ -110,6 +122,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
class SupervisedPrototypeModel(PrototypeModel): class SupervisedPrototypeModel(PrototypeModel):
proto_layer: LabeledComponents
def __init__(self, hparams, skip_proto_layer=False, **kwargs): def __init__(self, hparams, skip_proto_layer=False, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@ -129,13 +142,13 @@ class SupervisedPrototypeModel(PrototypeModel):
labels_initializer=labels_initializer, labels_initializer=labels_initializer,
) )
proto_shape = self.proto_layer.components.shape[1:] proto_shape = self.proto_layer.components.shape[1:]
self.hparams.initialized_proto_shape = proto_shape self.hparams["initialized_proto_shape"] = proto_shape
else: else:
# when restoring a checkpointed model # when restoring a checkpointed model
self.proto_layer = LabeledComponents( self.proto_layer = LabeledComponents(
distribution=distribution, distribution=distribution,
components_initializer=ZerosCompInitializer( components_initializer=ZerosCompInitializer(
self.hparams.initialized_proto_shape), self.hparams["initialized_proto_shape"]),
) )
self.competition_layer = WTAC() self.competition_layer = WTAC()
@ -156,7 +169,7 @@ class SupervisedPrototypeModel(PrototypeModel):
distances = self.compute_distances(x) distances = self.compute_distances(x)
_, plabels = self.proto_layer() _, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels) winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning, dim=1) y_pred = F.softmin(winning, dim=1)
return y_pred return y_pred
def predict_from_distances(self, distances): def predict_from_distances(self, distances):
@ -209,8 +222,10 @@ class NonGradientMixin(ProtoTorchMixin):
class ImagePrototypesMixin(ProtoTorchMixin): class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes.""" """Mixin for models with image prototypes."""
proto_layer: Components
components: torch.Tensor
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates.""" """Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0) self.proto_layer.components.data.clamp_(0.0, 1.0)

View File

@ -1,25 +1,30 @@
"""Lightning Callbacks.""" """Lightning Callbacks."""
import logging import logging
from typing import TYPE_CHECKING
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from prototorch.core.components import Components
from prototorch.core.initializers import LiteralCompInitializer from prototorch.core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology from .extras import ConnectionTopology
if TYPE_CHECKING:
from prototorch.models import GLVQ, GrowingNeuralGas
class PruneLoserPrototypes(pl.Callback): class PruneLoserPrototypes(pl.Callback):
def __init__(self, def __init__(
self,
threshold=0.01, threshold=0.01,
idle_epochs=10, idle_epochs=10,
prune_quota_per_epoch=-1, prune_quota_per_epoch=-1,
frequency=1, frequency=1,
replace=False, replace=False,
prototypes_initializer=None, prototypes_initializer=None,
verbose=False): verbose=False,
):
self.threshold = threshold # minimum win ratio self.threshold = threshold # minimum win ratio
self.idle_epochs = idle_epochs # epochs to wait before pruning self.idle_epochs = idle_epochs # epochs to wait before pruning
self.prune_quota_per_epoch = prune_quota_per_epoch self.prune_quota_per_epoch = prune_quota_per_epoch
@ -28,7 +33,7 @@ class PruneLoserPrototypes(pl.Callback):
self.verbose = verbose self.verbose = verbose
self.prototypes_initializer = prototypes_initializer self.prototypes_initializer = prototypes_initializer
def on_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module: "GLVQ"):
if (trainer.current_epoch + 1) < self.idle_epochs: if (trainer.current_epoch + 1) < self.idle_epochs:
return None return None
if (trainer.current_epoch + 1) % self.frequency: if (trainer.current_epoch + 1) % self.frequency:
@ -43,26 +48,28 @@ class PruneLoserPrototypes(pl.Callback):
prune_labels = prune_labels[:self.prune_quota_per_epoch] prune_labels = prune_labels[:self.prune_quota_per_epoch]
if len(to_prune) > 0: if len(to_prune) > 0:
if self.verbose: logging.debug(f"\nPrototype win ratios: {ratios}")
print(f"\nPrototype win ratios: {ratios}") logging.debug(f"Pruning prototypes at: {to_prune}")
print(f"Pruning prototypes at: {to_prune}") logging.debug(f"Corresponding labels are: {prune_labels.tolist()}")
print(f"Corresponding labels are: {prune_labels.tolist()}")
cur_num_protos = pl_module.num_prototypes cur_num_protos = pl_module.num_prototypes
pl_module.remove_prototypes(indices=to_prune) pl_module.remove_prototypes(indices=to_prune)
if self.replace: if self.replace:
labels, counts = torch.unique(prune_labels, labels, counts = torch.unique(prune_labels,
sorted=True, sorted=True,
return_counts=True) return_counts=True)
distribution = dict(zip(labels.tolist(), counts.tolist())) distribution = dict(zip(labels.tolist(), counts.tolist()))
if self.verbose:
print(f"Re-adding pruned prototypes...") logging.info(f"Re-adding pruned prototypes...")
print(f"distribution={distribution}") logging.debug(f"distribution={distribution}")
pl_module.add_prototypes( pl_module.add_prototypes(
distribution=distribution, distribution=distribution,
components_initializer=self.prototypes_initializer) components_initializer=self.prototypes_initializer)
new_num_protos = pl_module.num_prototypes new_num_protos = pl_module.num_prototypes
if self.verbose:
print(f"`num_prototypes` changed from {cur_num_protos} " logging.info(f"`num_prototypes` changed from {cur_num_protos} "
f"to {new_num_protos}.") f"to {new_num_protos}.")
return True return True
@ -74,11 +81,11 @@ class PrototypeConvergence(pl.Callback):
self.idle_epochs = idle_epochs # epochs to wait self.idle_epochs = idle_epochs # epochs to wait
self.verbose = verbose self.verbose = verbose
def on_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
if (trainer.current_epoch + 1) < self.idle_epochs: if (trainer.current_epoch + 1) < self.idle_epochs:
return None return None
if self.verbose:
print("Stopping...") logging.info("Stopping...")
# TODO # TODO
return True return True
@ -96,12 +103,16 @@ class GNGCallback(pl.Callback):
self.reduction = reduction self.reduction = reduction
self.freq = freq self.freq = freq
def on_epoch_end(self, trainer: pl.Trainer, pl_module): def on_train_epoch_end(
self,
trainer: pl.Trainer,
pl_module: "GrowingNeuralGas",
):
if (trainer.current_epoch + 1) % self.freq == 0: if (trainer.current_epoch + 1) % self.freq == 0:
# Get information # Get information
errors = pl_module.errors errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer topology: ConnectionTopology = pl_module.topology_layer
components: Components = pl_module.proto_layer.components components = pl_module.proto_layer.components
# Insertion point # Insertion point
worst = torch.argmax(errors) worst = torch.argmax(errors)
@ -121,8 +132,9 @@ class GNGCallback(pl.Callback):
# Add component # Add component
pl_module.proto_layer.add_components( pl_module.proto_layer.add_components(
None, 1,
initializer=LiteralCompInitializer(new_component.unsqueeze(0))) initializer=LiteralCompInitializer(new_component.unsqueeze(0)),
)
# Adjust Topology # Adjust Topology
topology.add_prototype() topology.add_prototype()

View File

@ -34,9 +34,9 @@ class GLVQ(SupervisedPrototypeModel):
# Loss # Loss
self.loss = GLVQLoss( self.loss = GLVQLoss(
margin=self.hparams.margin, margin=self.hparams["margin"],
transfer_fn=self.hparams.transfer_fn, transfer_fn=self.hparams["transfer_fn"],
beta=self.hparams.transfer_beta, beta=self.hparams["transfer_beta"],
) )
# def on_save_checkpoint(self, checkpoint): # def on_save_checkpoint(self, checkpoint):
@ -48,7 +48,7 @@ class GLVQ(SupervisedPrototypeModel):
"prototype_win_ratios", "prototype_win_ratios",
torch.zeros(self.num_prototypes, device=self.device)) torch.zeros(self.num_prototypes, device=self.device))
def on_epoch_start(self): def on_train_epoch_start(self):
self.initialize_prototype_win_ratios() self.initialize_prototype_win_ratios()
def log_prototype_win_ratios(self, distances): def log_prototype_win_ratios(self, distances):
@ -125,11 +125,11 @@ class SiameseGLVQ(GLVQ):
def configure_optimizers(self): def configure_optimizers(self):
proto_opt = self.optimizer(self.proto_layer.parameters(), proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams.proto_lr) lr=self.hparams["proto_lr"])
# Only add a backbone optimizer if backbone has trainable parameters # Only add a backbone optimizer if backbone has trainable parameters
bb_params = list(self.backbone.parameters()) bb_params = list(self.backbone.parameters())
if (bb_params): if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr) bb_opt = self.optimizer(bb_params, lr=self.hparams["bb_lr"])
optimizers = [proto_opt, bb_opt] optimizers = [proto_opt, bb_opt]
else: else:
optimizers = [proto_opt] optimizers = [proto_opt]
@ -199,12 +199,13 @@ class GRLVQ(SiameseGLVQ):
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise. TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
""" """
_relevances: torch.Tensor
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
# Additional parameters # Additional parameters
relevances = torch.ones(self.hparams.input_dim, device=self.device) relevances = torch.ones(self.hparams["input_dim"], device=self.device)
self.register_parameter("_relevances", Parameter(relevances)) self.register_parameter("_relevances", Parameter(relevances))
# Override the backbone # Override the backbone
@ -233,8 +234,8 @@ class SiameseGMLVQ(SiameseGLVQ):
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer()) EyeLinearTransformInitializer())
self.backbone = LinearTransform( self.backbone = LinearTransform(
self.hparams.input_dim, self.hparams["input_dim"],
self.hparams.latent_dim, self.hparams["latent_dim"],
initializer=omega_initializer, initializer=omega_initializer,
) )
@ -244,7 +245,7 @@ class SiameseGMLVQ(SiameseGLVQ):
@property @property
def lambda_matrix(self): def lambda_matrix(self):
omega = self.backbone.weight # (input_dim, latent_dim) omega = self.backbone.weights # (input_dim, latent_dim)
lam = omega @ omega.T lam = omega @ omega.T
return lam.detach().cpu() return lam.detach().cpu()
@ -257,6 +258,9 @@ class GMLVQ(GLVQ):
""" """
# Parameters
_omega: torch.Tensor
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", omega_distance) distance_fn = kwargs.pop("distance_fn", omega_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs) super().__init__(hparams, distance_fn=distance_fn, **kwargs)
@ -264,8 +268,8 @@ class GMLVQ(GLVQ):
# Additional parameters # Additional parameters
omega_initializer = kwargs.get("omega_initializer", omega_initializer = kwargs.get("omega_initializer",
EyeLinearTransformInitializer()) EyeLinearTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim, omega = omega_initializer.generate(self.hparams["input_dim"],
self.hparams.latent_dim) self.hparams["latent_dim"])
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega, self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix") name="omega matrix")
@ -299,8 +303,8 @@ class LGMLVQ(GMLVQ):
# Re-register `_omega` to override the one from the super class. # Re-register `_omega` to override the one from the super class.
omega = torch.randn( omega = torch.randn(
self.num_prototypes, self.num_prototypes,
self.hparams.input_dim, self.hparams["input_dim"],
self.hparams.latent_dim, self.hparams["latent_dim"],
device=self.device, device=self.device,
) )
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
@ -316,23 +320,27 @@ class GTLVQ(LGMLVQ):
omega_initializer = kwargs.get("omega_initializer") omega_initializer = kwargs.get("omega_initializer")
if omega_initializer is not None: if omega_initializer is not None:
subspace = omega_initializer.generate(self.hparams.input_dim, subspace = omega_initializer.generate(
self.hparams.latent_dim) self.hparams["input_dim"],
omega = torch.repeat_interleave(subspace.unsqueeze(0), self.hparams["latent_dim"],
)
omega = torch.repeat_interleave(
subspace.unsqueeze(0),
self.num_prototypes, self.num_prototypes,
dim=0) dim=0,
)
else: else:
omega = torch.rand( omega = torch.rand(
self.num_prototypes, self.num_prototypes,
self.hparams.input_dim, self.hparams["input_dim"],
self.hparams.latent_dim, self.hparams["latent_dim"],
device=self.device, device=self.device,
) )
# Re-register `_omega` to override the one from the super class. # Re-register `_omega` to override the one from the super class.
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx):
with torch.no_grad(): with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega)) self._omega.copy_(orthogonalization(self._omega))
@ -389,7 +397,7 @@ class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
""" """
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx):
"""Constrain the components to the range [0, 1] by clamping after updates.""" """Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0) self.proto_layer.components.data.clamp_(0.0, 1.0)
with torch.no_grad(): with torch.no_grad():

View File

@ -37,10 +37,7 @@ class KNN(SupervisedPrototypeModel):
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1 # skip training step return 1 # skip training step
def on_train_batch_start(self, def on_train_batch_start(self, train_batch, batch_idx):
train_batch,
batch_idx,
dataloader_idx=None):
warnings.warn("k-NN has no training, skipping!") warnings.warn("k-NN has no training, skipping!")
return -1 return -1

View File

@ -1,5 +1,7 @@
"""LVQ models that are optimized using non-gradient methods.""" """LVQ models that are optimized using non-gradient methods."""
import logging
from prototorch.core.losses import _get_dp_dm from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer from prototorch.nn.wrappers import LambdaLayer
@ -30,8 +32,8 @@ class LVQ1(NonGradientMixin, GLVQ):
self.proto_layer.load_state_dict({"_components": updated_protos}, self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False) strict=False)
print(f"dis={dis}") logging.debug(f"dis={dis}")
print(f"y={y}") logging.debug(f"y={y}")
# Logging # Logging
self.log_acc(dis, y, tag="train_acc") self.log_acc(dis, y, tag="train_acc")
@ -74,8 +76,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
""" """
def __init__(self, hparams, verbose=True, **kwargs): def __init__(self, hparams, **kwargs):
self.verbose = verbose
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer( self.transfer_layer = LambdaLayer(
@ -116,8 +117,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
_protos[i] = xk _protos[i] = xk
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma) _lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound: if _lower_bound > lower_bound:
if self.verbose: logging.debug(f"Updating prototype {i} to data {k}...")
print(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict({"_components": _protos}, self.proto_layer.load_state_dict({"_components": _protos},
strict=False) strict=False)
break break

View File

@ -37,17 +37,24 @@ class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs): def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
self.conditional_distribution = None
self.rejection_confidence = rejection_confidence self.rejection_confidence = rejection_confidence
self._conditional_distribution = None
def forward(self, x): def forward(self, x):
distances = self.compute_distances(x) distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances) conditional = self.conditional_distribution(distances)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes, prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device) device=self.device)
posterior = conditional * prior posterior = conditional * prior
plabels = self.proto_layer._labels plabels = self.proto_layer._labels
y_pred = stratified_sum_pooling(posterior, plabels) if isinstance(plabels, torch.LongTensor) or isinstance(
plabels, torch.cuda.LongTensor): # type: ignore
y_pred = stratified_sum_pooling(posterior, plabels) # type: ignore
else:
raise ValueError("Labels must be LongTensor.")
return y_pred return y_pred
def predict(self, x): def predict(self, x):
@ -64,6 +71,12 @@ class ProbabilisticLVQ(GLVQ):
loss = batch_loss.sum() loss = batch_loss.sum()
return loss return loss
def conditional_distribution(self, distances):
"""Conditional distribution of distances."""
if self._conditional_distribution is None:
raise ValueError("Conditional distribution is not set.")
return self._conditional_distribution(distances)
class SLVQ(ProbabilisticLVQ): class SLVQ(ProbabilisticLVQ):
"""Soft Learning Vector Quantization.""" """Soft Learning Vector Quantization."""
@ -75,7 +88,7 @@ class SLVQ(ProbabilisticLVQ):
self.hparams.setdefault("variance", 1.0) self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance") variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance) self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(nllr_loss) self.loss = LossLayer(nllr_loss)
@ -89,7 +102,7 @@ class RSLVQ(ProbabilisticLVQ):
self.hparams.setdefault("variance", 1.0) self.hparams.setdefault("variance", 1.0)
variance = self.hparams.get("variance") variance = self.hparams.get("variance")
self.conditional_distribution = GaussianPrior(variance) self._conditional_distribution = GaussianPrior(variance)
self.loss = LossLayer(rslvq_loss) self.loss = LossLayer(rslvq_loss)

View File

@ -17,6 +17,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
TODO Allow non-2D grids TODO Allow non-2D grids
""" """
_grid: torch.Tensor
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
h, w = hparams.get("shape") h, w = hparams.get("shape")
@ -92,10 +93,10 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.hparams.setdefault("age_limit", 10) self.hparams.setdefault("age_limit", 10)
self.hparams.setdefault("lm", 1) self.hparams.setdefault("lm", 1)
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm) self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
self.topology_layer = ConnectionTopology( self.topology_layer = ConnectionTopology(
agelimit=self.hparams.age_limit, agelimit=self.hparams["age_limit"],
num_prototypes=self.hparams.num_prototypes, num_prototypes=self.hparams["num_prototypes"],
) )
def training_step(self, train_batch, batch_idx): def training_step(self, train_batch, batch_idx):
@ -108,12 +109,9 @@ class NeuralGas(UnsupervisedPrototypeModel):
self.log("loss", loss) self.log("loss", loss)
return loss return loss
# def training_epoch_end(self, training_step_outputs):
# print(f"{self.trainer.lr_schedulers}")
# print(f"{self.trainer.lr_schedulers[0]['scheduler'].optimizer}")
class GrowingNeuralGas(NeuralGas): class GrowingNeuralGas(NeuralGas):
errors: torch.Tensor
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)
@ -123,7 +121,10 @@ class GrowingNeuralGas(NeuralGas):
self.hparams.setdefault("insert_reduction", 0.1) self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10) self.hparams.setdefault("insert_freq", 10)
errors = torch.zeros(self.hparams.num_prototypes, device=self.device) errors = torch.zeros(
self.hparams["num_prototypes"],
device=self.device,
)
self.register_buffer("errors", errors) self.register_buffer("errors", errors)
def training_step(self, train_batch, _batch_idx): def training_step(self, train_batch, _batch_idx):
@ -138,7 +139,7 @@ class GrowingNeuralGas(NeuralGas):
dp = d * mask dp = d * mask
self.errors += torch.sum(dp * dp) self.errors += torch.sum(dp * dp)
self.errors *= self.hparams.step_reduction self.errors *= self.hparams["step_reduction"]
self.topology_layer(d) self.topology_layer(d)
self.log("loss", loss) self.log("loss", loss)
@ -147,7 +148,7 @@ class GrowingNeuralGas(NeuralGas):
def configure_callbacks(self): def configure_callbacks(self):
return [ return [
GNGCallback( GNGCallback(
reduction=self.hparams.insert_reduction, reduction=self.hparams["insert_reduction"],
freq=self.hparams.insert_freq, freq=self.hparams["insert_freq"],
) )
] ]

View File

@ -1,5 +1,8 @@
"""Visualization Callbacks.""" """Visualization Callbacks."""
import warnings
from typing import Sized
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
@ -7,6 +10,7 @@ import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.utils.colors import get_colors, get_legend_handles from prototorch.utils.colors import get_colors, get_legend_handles
from prototorch.utils.utils import mesh2d from prototorch.utils.utils import mesh2d
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
@ -33,8 +37,13 @@ class Vis2DAbstract(pl.Callback):
if data: if data:
if isinstance(data, Dataset): if isinstance(data, Dataset):
if isinstance(data, Sized):
x, y = next(iter(DataLoader(data, batch_size=len(data)))) x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader): else:
# TODO: Add support for non-sized datasets
raise NotImplementedError(
"Data must be a dataset with a __len__ method.")
elif isinstance(data, DataLoader):
x = torch.tensor([]) x = torch.tensor([])
y = torch.tensor([]) y = torch.tensor([])
for x_b, y_b in data: for x_b, y_b in data:
@ -122,7 +131,7 @@ class Vis2DAbstract(pl.Callback):
else: else:
plt.show(block=self.block) plt.show(block=self.block)
def on_epoch_end(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.precheck(trainer):
return True return True
self.visualize(pl_module) self.visualize(pl_module)
@ -131,6 +140,9 @@ class Vis2DAbstract(pl.Callback):
def on_train_end(self, trainer, pl_module): def on_train_end(self, trainer, pl_module):
plt.close() plt.close()
def visualize(self, pl_module):
raise NotImplementedError
class VisGLVQ2D(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):
@ -291,9 +303,13 @@ class VisImgComp(Vis2DAbstract):
self.add_embedding = add_embedding self.add_embedding = add_embedding
self.embedding_data = embedding_data self.embedding_data = embedding_data
def on_train_start(self, trainer, pl_module): def on_train_start(self, _, pl_module):
if isinstance(pl_module.logger, TensorBoardLogger):
tb = pl_module.logger.experiment tb = pl_module.logger.experiment
# Add embedding
if self.add_embedding: if self.add_embedding:
if self.x_train is not None and self.y_train is not None:
ind = np.random.choice(len(self.x_train), ind = np.random.choice(len(self.x_train),
size=self.embedding_data, size=self.embedding_data,
replace=False) replace=False)
@ -304,17 +320,28 @@ class VisImgComp(Vis2DAbstract):
tag="Data Embedding", tag="Data Embedding",
metadata=self.y_train[ind], metadata=self.y_train[ind],
metadata_header=None) metadata_header=None)
else:
raise ValueError("No data for add embedding flag")
# Random Data
if self.random_data: if self.random_data:
if self.x_train is not None:
ind = np.random.choice(len(self.x_train), ind = np.random.choice(len(self.x_train),
size=self.random_data, size=self.random_data,
replace=False) replace=False)
data = self.x_train[ind] data = self.x_train[ind]
grid = torchvision.utils.make_grid(data, nrow=self.num_columns) grid = torchvision.utils.make_grid(data,
nrow=self.num_columns)
tb.add_image(tag="Data", tb.add_image(tag="Data",
img_tensor=grid, img_tensor=grid,
global_step=None, global_step=None,
dataformats=self.dataformats) dataformats=self.dataformats)
else:
raise ValueError("No data for random data flag")
else:
warnings.warn(
f"TensorBoardLogger is required, got {type(pl_module.logger)}")
def add_to_tensorboard(self, trainer, pl_module): def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment tb = pl_module.logger.experiment