13 Commits

Author SHA1 Message Date
Jensun Ravichandran
aeb6417c28 refactor: minor changes in probabilistic.py 2021-08-06 13:49:29 +02:00
Jensun Ravichandran
cb7fb91c95 feat: add binnam_xor.py 2021-07-15 18:19:28 +02:00
Jensun Ravichandran
823b05e390 feat: add neural additive model for binary classification 2021-07-14 20:07:34 +02:00
Jensun Ravichandran
f8ad1d83eb refactor: clean up abstract classes 2021-07-14 19:17:05 +02:00
Jensun Ravichandran
23a3683860 fix(doc): update outdated 2021-07-12 21:21:29 +02:00
Jensun Ravichandran
4be9fb81eb feat(model): implement MedianLVQ 2021-07-06 17:12:51 +02:00
Jensun Ravichandran
9d38123114 refactor: use GLVQLoss instead of LossLayer 2021-07-06 17:09:21 +02:00
Jensun Ravichandran
0f9f24e36a feat: add early-stopping and pruning to examples/warm_starting.py 2021-06-30 16:04:26 +02:00
Jensun Ravichandran
09e3ef1d0e fix: remove deprecated Trainer.accelerator_backend 2021-06-30 16:03:45 +02:00
Alexander Engelsberger
7b9b767113 fix: training loss is a zero dimensional tensor
Should fix the problem with EarlyStopping callback.
2021-06-25 17:07:06 +02:00
Jensun Ravichandran
f56ec44afe chore(github): update bug report issue template 2021-06-25 17:07:06 +02:00
Jensun Ravichandran
67a20124e8 chore(github): add issue templates 2021-06-25 17:07:06 +02:00
Jensun Ravichandran
72af03b991 refactor: use LinearTransform instead of torch.nn.Linear 2021-06-25 17:07:06 +02:00
21 changed files with 622 additions and 413 deletions

38
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@@ -0,0 +1,38 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Steps to reproduce the behavior**
1. ...
2. Run script '...' or this snippet:
```python
import prototorch as pt
...
```
3. See errors
**Expected behavior**
A clear and concise description of what you expected to happen.
**Observed behavior**
A clear and concise description of what actually happened.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**System and version information**
- OS: [e.g. Ubuntu 20.10]
- ProtoTorch Version: [e.g. 0.4.0]
- Python Version: [e.g. 3.9.5]
**Additional context**
Add any other context about the problem here.

View File

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

View File

@@ -36,6 +36,7 @@ be available for use in your Python environment as `prototorch.models`.
- Soft Learning Vector Quantization (SLVQ)
- Robust Soft Learning Vector Quantization (RSLVQ)
- Probabilistic Learning Vector Quantization (PLVQ)
- Median-LVQ
### Other
@@ -51,7 +52,6 @@ be available for use in your Python environment as `prototorch.models`.
## Planned models
- Median-LVQ
- Generalized Tangent Learning Vector Quantization (GTLVQ)
- Self-Incremental Learning Vector Quantization (SILVQ)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,81 @@
"""Neural Additive Model (NAM) example for binary classification."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Tecator("~/datasets")
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(lr=0.1)
# Define the feature extractor
class FE(torch.nn.Module):
def __init__(self):
super().__init__()
self.modules_list = torch.nn.ModuleList([
torch.nn.Linear(1, 3),
torch.nn.Sigmoid(),
torch.nn.Linear(3, 1),
torch.nn.Sigmoid(),
])
def forward(self, x):
for m in self.modules_list:
x = m(x)
return x
# Initialize the model
model = pt.models.BinaryNAM(
hparams,
extractors=torch.nn.ModuleList([FE() for _ in range(100)]),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 100)
# Callbacks
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=20,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
es,
],
terminate_on_nan=True,
weights_summary=None,
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
# Visualize extractor shape functions
fig, axes = plt.subplots(10, 10)
for i, ax in enumerate(axes.flat):
x = torch.linspace(-2, 2, 100) # TODO use min/max from data
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
ax.plot(x, y)
ax.set(title=f"Feature {i + 1}", xticklabels=[], yticklabels=[])
plt.show()

86
examples/binnam_xor.py Normal file
View File

@@ -0,0 +1,86 @@
"""Neural Additive Model (NAM) example for binary classification."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.XOR()
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
# Hyperparameters
hparams = dict(lr=0.001)
# Define the feature extractor
class FE(torch.nn.Module):
def __init__(self, hidden_size=10):
super().__init__()
self.modules_list = torch.nn.ModuleList([
torch.nn.Linear(1, hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(hidden_size, 1),
torch.nn.ReLU(),
])
def forward(self, x):
for m in self.modules_list:
x = m(x)
return x
# Initialize the model
model = pt.models.BinaryNAM(
hparams,
extractors=torch.nn.ModuleList([FE(20) for _ in range(2)]),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.Vis2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=50,
mode="min",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
terminate_on_nan=True,
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
# Visualize extractor shape functions
fig, axes = plt.subplots(2)
for i, ax in enumerate(axes.flat):
x = torch.linspace(0, 1, 100) # TODO use min/max from data
y = model.extractors[i](x.view(100, 1)).squeeze().detach()
ax.plot(x, y)
ax.set(title=f"Feature {i + 1}")
plt.show()

View File

@@ -1,12 +1,11 @@
"""GMLVQ example using the MNIST dataset."""
import torch
from pytorch_lightning.utilities.cli import LightningCLI
import prototorch as pt
import torch
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
from pytorch_lightning.utilities.cli import LightningCLI
class ExperimentClass(ImageGMLVQ):

View File

@@ -66,7 +66,7 @@ if __name__ == "__main__":
args,
callbacks=[
vis,
# es, # FIXME
es,
pruning,
],
terminate_on_nan=True,

View File

@@ -2,12 +2,11 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
import prototorch as pt
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()

View File

@@ -0,0 +1,52 @@
"""Median-LVQ example using the Iris dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = torch.utils.data.DataLoader(
train_ds,
batch_size=len(train_ds), # MedianLVQ cannot handle mini-batches
)
# Initialize the model
model = pt.models.MedianLVQ(
hparams=dict(distribution=(3, 2), lr=0.01),
prototypes_initializer=pt.initializers.SSCI(train_ds),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.01,
patience=5,
mode="max",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis, es],
weights_summary="full",
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -37,7 +37,7 @@ if __name__ == "__main__":
# Setup trainer for GNG
trainer = pl.Trainer(
max_epochs=200,
max_epochs=100,
callbacks=[es],
weights_summary=None,
)
@@ -71,11 +71,30 @@ if __name__ == "__main__":
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
pruning = pt.models.PruneLoserPrototypes(
threshold=0.02,
idle_epochs=2,
prune_quota_per_epoch=5,
frequency=1,
verbose=True,
)
es = pl.callbacks.EarlyStopping(
monitor="train_loss",
min_delta=0.001,
patience=10,
mode="min",
verbose=True,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
callbacks=[
vis,
pruning,
es,
],
weights_summary="full",
accelerator="ddp",
)

View File

@@ -19,6 +19,7 @@ from .glvq import (
)
from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ
from .nam import BinaryNAM
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .vis import *

View File

@@ -14,20 +14,8 @@ from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
class ProtoTorchMixin(object):
pass
class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts."""
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped
class PrototypeModel(ProtoTorchBolt):
def __init__(self, hparams, **kwargs):
super().__init__()
@@ -42,22 +30,6 @@ class PrototypeModel(ProtoTorchBolt):
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
@property
def num_prototypes(self):
return len(self.proto_layer.components)
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
if self.lr_scheduler is not None:
@@ -73,7 +45,34 @@ class PrototypeModel(ProtoTorchBolt):
@final
def reconfigure_optimizers(self):
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
self.trainer.accelerator.setup_optimizers(self.trainer)
def __repr__(self):
surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
wrapped = f"ProtoTorch Bolt(\n{indented})"
return wrapped
class PrototypeModel(ProtoTorchBolt):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
self.distance_layer = LambdaLayer(distance_fn)
@property
def num_prototypes(self):
return len(self.proto_layer.components)
@property
def prototypes(self):
return self.proto_layer.components.detach().cpu()
@property
def components(self):
"""Only an alias for the prototypes."""
return self.prototypes
def add_prototypes(self, *args, **kwargs):
self.proto_layer.add_components(*args, **kwargs)
@@ -167,6 +166,11 @@ class SupervisedPrototypeModel(PrototypeModel):
logger=True)
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):

View File

@@ -48,7 +48,7 @@ class CBC(SiameseGLVQ):
y_pred = self(x)
num_classes = self.num_classes
y_true = torch.nn.functional.one_hot(y.long(), num_classes=num_classes)
loss = self.loss(y_pred, y_true).mean(dim=0)
loss = self.loss(y_pred, y_true).mean()
return y_pred, loss
def training_step(self, batch, batch_idx, optimizer_idx=None):

View File

@@ -5,13 +5,12 @@ Mainly used for PytorchLightningCLI configurations.
"""
from typing import Any, Optional, Type
import prototorch as pt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import prototorch as pt
# MNIST
class MNISTDataModule(pl.LightningDataModule):

View File

@@ -6,8 +6,8 @@ from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.initializers import EyeTransformInitializer
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
from ..nn.activations import get_activation
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from ..core.transforms import LinearTransform
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@@ -18,15 +18,16 @@ class GLVQ(SupervisedPrototypeModel):
super().__init__(hparams, **kwargs)
# Default hparams
self.hparams.setdefault("margin", 0.0)
self.hparams.setdefault("transfer_fn", "identity")
self.hparams.setdefault("transfer_beta", 10.0)
# Layers
transfer_fn = get_activation(self.hparams.transfer_fn)
self.transfer_layer = LambdaLayer(transfer_fn)
# Loss
self.loss = LossLayer(glvq_loss)
self.loss = GLVQLoss(
margin=self.hparams.margin,
transfer_fn=self.hparams.transfer_fn,
beta=self.hparams.transfer_beta,
)
def initialize_prototype_win_ratios(self):
self.register_buffer(
@@ -55,9 +56,7 @@ class GLVQ(SupervisedPrototypeModel):
x, y = batch
out = self.compute_distances(x)
plabels = self.proto_layer.labels
mu = self.loss(out, y, prototype_labels=plabels)
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
loss = self.loss(out, y, plabels)
return out, loss
def training_step(self, batch, batch_idx, optimizer_idx=None):
@@ -208,18 +207,22 @@ class SiameseGMLVQ(SiameseGLVQ):
super().__init__(hparams, **kwargs)
# Override the backbone
self.backbone = torch.nn.Linear(self.hparams.input_dim,
self.hparams.latent_dim,
bias=False)
omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer())
self.backbone = LinearTransform(
self.hparams.input_dim,
self.hparams.output_dim,
initializer=omega_initializer,
)
@property
def omega_matrix(self):
return self.backbone.weight.detach().cpu()
return self.backbone.weights
@property
def lambda_matrix(self):
omega = self.backbone.weight # (latent_dim, input_dim)
lam = omega.T @ omega
omega = self.backbone.weight # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()

View File

@@ -1,6 +1,8 @@
"""LVQ models that are optimized using non-gradient methods."""
from ..core.losses import _get_dp_dm
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin
from .glvq import GLVQ
@@ -66,4 +68,61 @@ class LVQ21(NonGradientMixin, GLVQ):
class MedianLVQ(NonGradientMixin, GLVQ):
"""Median LVQ"""
"""Median LVQ
# TODO Avoid computing distances over and over
"""
def __init__(self, hparams, verbose=True, **kwargs):
self.verbose = verbose
super().__init__(hparams, **kwargs)
self.transfer_layer = LambdaLayer(
get_activation(self.hparams.transfer_fn))
def _f(self, x, y, protos, plabels):
d = self.distance_layer(x, protos)
dp, dm = _get_dp_dm(d, y, plabels)
mu = (dp - dm) / (dp + dm)
invmu = -1.0 * mu
f = self.transfer_layer(invmu, beta=self.hparams.transfer_beta) + 1.0
return f
def expectation(self, x, y, protos, plabels):
f = self._f(x, y, protos, plabels)
gamma = f / f.sum()
return gamma
def lower_bound(self, x, y, protos, plabels, gamma):
f = self._f(x, y, protos, plabels)
lower_bound = (gamma * f.log()).sum()
return lower_bound
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
x, y = train_batch
dis = self.compute_distances(x)
for i, _ in enumerate(protos):
# Expectation step
gamma = self.expectation(x, y, protos, plabels)
lower_bound = self.lower_bound(x, y, protos, plabels, gamma)
# Maximization step
_protos = protos + 0
for k, xk in enumerate(x):
_protos[i] = xk
_lower_bound = self.lower_bound(x, y, _protos, plabels, gamma)
if _lower_bound > lower_bound:
if self.verbose:
print(f"Updating prototype {i} to data {k}...")
self.proto_layer.load_state_dict({"_components": _protos},
strict=False)
break
# Logging
self.log_acc(dis, y, tag="train_acc")
return None

58
prototorch/models/nam.py Normal file
View File

@@ -0,0 +1,58 @@
"""ProtoTorch Neural Additive Model."""
import torch
import torchmetrics
from .abstract import ProtoTorchBolt
class BinaryNAM(ProtoTorchBolt):
"""Neural Additive Model for binary classification.
Paper: https://arxiv.org/abs/2004.13912
Official implementation: https://github.com/google-research/google-research/tree/master/neural_additive_models
"""
def __init__(self, hparams: dict, extractors: torch.nn.ModuleList,
**kwargs):
super().__init__(hparams, **kwargs)
# Default hparams
self.hparams.setdefault("threshold", 0.5)
self.extractors = extractors
self.linear = torch.nn.Linear(in_features=len(extractors),
out_features=1,
bias=True)
def extract(self, x):
"""Apply the local extractors batch-wise on features."""
out = torch.zeros_like(x)
for j in range(x.shape[1]):
out[:, j] = self.extractors[j](x[:, j].unsqueeze(1)).squeeze()
return out
def forward(self, x):
x = self.extract(x)
x = self.linear(x)
return torch.sigmoid(x)
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
preds = self(x).squeeze()
train_loss = torch.nn.functional.binary_cross_entropy(preds, y.float())
self.log("train_loss", train_loss)
accuracy = torchmetrics.functional.accuracy(preds.int(), y.int())
self.log("train_acc",
accuracy,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True)
return train_loss
def predict(self, x):
out = self(x)
pred = torch.zeros_like(out, device=self.device)
pred[out > self.hparams.threshold] = 1
return pred

View File

@@ -1,5 +1,4 @@
"""Probabilistic GLVQ methods"""
import torch
from ..core.losses import nllr_loss, rslvq_loss
@@ -24,7 +23,7 @@ class CELVQ(GLVQ):
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning
batch_loss = self.loss(probs, y.long())
loss = batch_loss.sum(dim=0)
loss = batch_loss.sum()
return out, loss
@@ -32,7 +31,7 @@ class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs)
self.conditional_distribution = None
self.conditional_distribution = GaussianPrior(self.hparams.variance)
self.rejection_confidence = rejection_confidence
def forward(self, x):
@@ -56,8 +55,9 @@ class ProbabilisticLVQ(GLVQ):
out = self.forward(x)
plabels = self.proto_layer.labels
batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum(dim=0)
return loss
train_loss = batch_loss.sum()
self.log("train_loss", train_loss)
return train_loss
class SLVQ(ProbabilisticLVQ):
@@ -65,7 +65,6 @@ class SLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss = LossLayer(nllr_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class RSLVQ(ProbabilisticLVQ):
@@ -73,7 +72,6 @@ class RSLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss = LossLayer(rslvq_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
@@ -92,5 +90,5 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
# x, y = batch
# y_pred = self(x)
# batch_loss = self.loss(y_pred, y)
# loss = batch_loss.sum(dim=0)
# loss = batch_loss.sum()
# return loss

View File

@@ -132,7 +132,7 @@ class GrowingNeuralGas(NeuralGas):
mask[torch.arange(len(mask)), winner] = 1.0
dp = d * mask
self.errors += torch.sum(dp * dp, dim=0)
self.errors += torch.sum(dp * dp)
self.errors *= self.hparams.step_reduction
self.topology_layer(d)

View File

@@ -117,6 +117,24 @@ class Vis2DAbstract(pl.Callback):
plt.close()
class Vis2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
mesh_input = torch.from_numpy(mesh_input).type_as(x_train)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):