16 Commits

Author SHA1 Message Date
Alexander Engelsberger
5ce326ce62 feat: CLCC register torchmetrics added 2021-10-15 15:18:02 +02:00
Alexander Engelsberger
d1985571b3 feat: Improve 2D visualization with Voronoi Cells 2021-10-15 13:01:01 +02:00
Alexander Engelsberger
967953442b feat: Add basic GLVQ with new architecture 2021-10-14 15:49:12 +02:00
Alexander Engelsberger
d4448f2bc9 chore(pre-commit): Update plugin versions and rerun all files 2021-10-13 10:54:53 +02:00
Alexander Engelsberger
a8829945f5 chore: Move mixins into seperate file 2021-10-11 16:05:12 +02:00
Alexander Engelsberger
a8336ee213 chore: Remove relative imports 2021-10-11 15:45:43 +02:00
Alexander Engelsberger
6ffd27d12a chore: Remove PytorchLightning CLI related code
Could be moved in a seperate plugin.
2021-10-11 15:16:12 +02:00
Alexander Engelsberger
859e2cae69 docs(dependencies): Add missing ipykernel dependency for docs 2021-10-11 15:11:53 +02:00
Alexander Engelsberger
d7ea89d47e feat: add simple test step 2021-09-10 19:19:51 +02:00
Jensun Ravichandran
fa928afe2c feat(vis): 2D EV projection for GMLVQ 2021-09-01 10:49:57 +02:00
Alexander Engelsberger
7d4a041df2 build: bump version 0.2.0 → 0.3.0 2021-08-30 20:50:03 +02:00
Alexander Engelsberger
04c51c00c6 ci: seperate build step 2021-08-30 20:44:16 +02:00
Alexander Engelsberger
62185b38cf chore: Update prototorch dependency 2021-08-30 20:32:47 +02:00
Alexander Engelsberger
7b93cd4ad5 feat(compatibility): Python3.6 compatibility 2021-08-30 20:32:40 +02:00
Alexander Engelsberger
d7834e2cc0 fix: All examples should work on CPU and GPU now 2021-08-05 11:20:02 +02:00
Alexander Engelsberger
0af8cf36f8 fix: labels where on cpu in forward pass 2021-08-05 09:14:32 +02:00
30 changed files with 702 additions and 335 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.0
current_version = 0.3.0
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View File

@@ -18,12 +18,12 @@ repos:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.8.0
rev: 5.9.3
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.902
rev: v0.910-1
hooks:
- id: mypy
files: prototorch
@@ -42,9 +42,10 @@ repos:
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
rev: v2.29.0
hooks:
- id: pyupgrade
args: [--py36-plus]
- repo: https://github.com/si-cim/gitlint
rev: v0.15.2-unofficial

View File

@@ -1,7 +1,11 @@
dist: bionic
sudo: false
language: python
python: 3.9
python:
- 3.9
- 3.8
- 3.7
- 3.6
cache:
directories:
- "$HOME/.cache/pip"
@@ -15,11 +19,26 @@ script:
- ./tests/test_examples.sh examples/
after_success:
- bash <(curl -s https://codecov.io/bash)
deploy:
provider: pypi
username: __token__
password:
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
on:
tags: true
skip_existing: true
# Publish on PyPI
jobs:
include:
- stage: build
python: 3.9
script: echo "Starting Pypi build"
deploy:
provider: pypi
username: __token__
distributions: "sdist bdist_wheel"
password:
secure: PDoASdYdVlt1aIROYilAsCW6XpBs/TDel0CSptDzX0CI7i4+ksEW6Jk0JyL58bQt7V4F8PeGty4A8SODzAUIk2d8sty5RI4VJjvXZFCXlUsW+JGUN3EvWNqJLnwN8TDxgu2ENao37GUh0dC6pL8b6bVDGeOLaY1E/YR1jimmTJuxxjKjBIU8ByqTNBnC3rzybMTPU3nRoOM/WMQUyReHrPoUJj685sLqrLruhAqhiYsPbotP8xY6i8+KBbhp5vgiARV2+LkbeGcYZwozCzrEqPKY7YIfVPh895cw0v4NRyFwK1P2jyyIt22Z9Ni0Uy1J5/Qp9Sv6mBPeGjm3pnpDCQyS+2bNIDaj08KUYTIo1mC/Jcu4jQgppZEF+oey9q1tgGo+/JhsTeERKV9BoPF5HDiRArU1s5aWJjFnCsHfu+W1XqX8bwN3aTYsEIaApT3/irc6XyFJIfMN82+z+lUcZ4Y1yAHT3nH1Vif+pZYZB0UOSGrHwuI/UayjKzbCzHMuHWylWB/9ehd4o4YVp6iubVHc7Sj0KQkwBgwgl6TvwNcUuFsplFabCxmX0mVcavXsWiOBc+ivPmU6574zGj0JcEk5ghVgnKH+QS96aVrKOzegwbl4O13jY8dJp+/zgXl0gJOvRKr4BhuBJKcBaMQHdSKUChVsJJtqDyt59GvWcbg=
on:
tags: true
skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "0.2.0"
release = "0.3.0"
# -- General configuration ---------------------------------------------------

View File

@@ -38,10 +38,12 @@ if __name__ == "__main__":
)
# Callbacks
vis = pt.models.VisCBC2D(data=train_ds,
title="CBC Iris Example",
resolution=100,
axis_off=True)
vis = pt.models.Visualize2DVoronoiCallback(
data=train_ds,
title="CBC Iris Example",
resolution=100,
axis_off=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(

View File

@@ -1,8 +0,0 @@
# Examples using Lightning CLI
Examples in this folder use the experimental [Lightning CLI](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_cli.html).
To use the example run
```
python gmlvq.py --config gmlvq.yaml
```

View File

@@ -1,19 +0,0 @@
"""GMLVQ example using the MNIST dataset."""
import prototorch as pt
import torch
from prototorch.models import ImageGMLVQ
from prototorch.models.abstract import PrototypeModel
from prototorch.models.data import MNISTDataModule
from pytorch_lightning.utilities.cli import LightningCLI
class ExperimentClass(ImageGMLVQ):
def __init__(self, hparams, **kwargs):
super().__init__(hparams,
optimizer=torch.optim.Adam,
prototype_initializer=pt.components.zeros(28 * 28),
**kwargs)
cli = LightningCLI(ImageGMLVQ, MNISTDataModule)

View File

@@ -1,11 +0,0 @@
model:
hparams:
input_dim: 784
latent_dim: 784
distribution:
num_classes: 10
prototypes_per_class: 2
proto_lr: 0.01
bb_lr: 0.01
data:
batch_size: 32

View File

@@ -3,6 +3,7 @@
import argparse
import prototorch as pt
import prototorch.models.clcc
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
@@ -29,7 +30,7 @@ if __name__ == "__main__":
)
# Initialize the model
model = pt.models.GLVQ(
model = prototorch.models.GLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
@@ -41,7 +42,13 @@ if __name__ == "__main__":
model.example_input_array = torch.zeros(4, 2)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
vis = pt.models.Visualize2DVoronoiCallback(
data=train_ds,
resolution=200,
title="Example: GLVQ on Iris",
x_label="sepal length",
y_label="petal length",
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(

58
examples/gmlvq_iris.py Normal file
View File

@@ -0,0 +1,58 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=4,
latent_dim=4,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 4)
# Callbacks
vis = pt.models.VisGMLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -6,6 +6,7 @@ import prototorch as pt
import pytorch_lightning as pl
import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
if __name__ == "__main__":
# Command-line arguments
@@ -14,12 +15,20 @@ if __name__ == "__main__":
args = parser.parse_args()
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
X, y = load_iris(return_X_y=True)
X = X[:, [0, 2]]
X_train, X_test, y_train, y_test = train_test_split(X,
y,
test_size=0.5,
random_state=42)
train_ds = pt.datasets.NumpyDataset(X_train, y_train)
test_ds = pt.datasets.NumpyDataset(X_test, y_test)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16)
# Hyperparameters
hparams = dict(k=5)
@@ -35,7 +44,7 @@ if __name__ == "__main__":
# Callbacks
vis = pt.models.VisGLVQ2D(
data=(x_train, y_train),
data=(X_train, y_train),
resolution=200,
block=True,
)
@@ -53,5 +62,8 @@ if __name__ == "__main__":
trainer.fit(model, train_loader)
# Recall
y_pred = model.predict(torch.tensor(x_train))
y_pred = model.predict(torch.tensor(X_train))
print(y_pred)
# Test
trainer.test(model, dataloaders=test_loader)

View File

@@ -1,7 +1,5 @@
"""`models` plugin for the `prototorch` package."""
from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
@@ -23,4 +21,4 @@ from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .vis import *
__version__ = "0.2.0"
__version__ = "0.3.0"

View File

@@ -1,17 +1,14 @@
"""Abstract classes to be inherited by prototorch models."""
from typing import Final, final
import pytorch_lightning as pl
import torch
import torchmetrics
from ..core.competitions import WTAC
from ..core.components import Components, LabeledComponents
from ..core.distances import euclidean_distance
from ..core.initializers import LabelsInitializer
from ..core.pooling import stratified_min_pooling
from ..nn.wrappers import LambdaLayer
from prototorch.core.competitions import WTAC
from prototorch.core.components import Components, LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import LabelsInitializer
from prototorch.core.pooling import stratified_min_pooling
from prototorch.nn.wrappers import LambdaLayer
class ProtoTorchBolt(pl.LightningModule):
@@ -43,7 +40,6 @@ class ProtoTorchBolt(pl.LightningModule):
else:
return optimizer
@final
def reconfigure_optimizers(self):
self.trainer.accelerator.setup_optimizers(self.trainer)
@@ -96,7 +92,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
)
def compute_distances(self, x):
protos = self.proto_layer()
protos = self.proto_layer().type_as(x)
distances = self.distance_layer(x, protos)
return distances
@@ -136,14 +132,14 @@ class SupervisedPrototypeModel(PrototypeModel):
def forward(self, x):
distances = self.compute_distances(x)
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
winning = stratified_min_pooling(distances, plabels)
y_pred = torch.nn.functional.softmin(winning)
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
y_pred = self.competition_layer(distances, plabels)
return y_pred
@@ -165,32 +161,10 @@ class SupervisedPrototypeModel(PrototypeModel):
prog_bar=True,
logger=True)
def test_step(self, batch, batch_idx):
x, targets = batch
class ProtoTorchMixin(object):
"""All mixins are ProtoTorchMixins."""
pass
preds = self.predict(x)
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization: Final = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
@final
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()
self.log("test_acc", accuracy)

View File

@@ -4,9 +4,9 @@ import logging
import pytorch_lightning as pl
import torch
from prototorch.core.components import Components
from prototorch.core.initializers import LiteralCompInitializer
from ..core.components import Components
from ..core.initializers import LiteralCompInitializer
from .extras import ConnectionTopology
@@ -55,7 +55,7 @@ class PruneLoserPrototypes(pl.Callback):
distribution = dict(zip(labels.tolist(), counts.tolist()))
if self.verbose:
print(f"Re-adding pruned prototypes...")
print(f"{distribution=}")
print(f"distribution={distribution}")
pl_module.add_prototypes(
distribution=distribution,
components_initializer=self.prototypes_initializer)
@@ -134,4 +134,4 @@ class GNGCallback(pl.Callback):
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer)
trainer.accelerator.setup_optimizers(trainer)

View File

@@ -1,14 +1,14 @@
import torch
import torchmetrics
from prototorch.core.competitions import CBCC
from prototorch.core.components import ReasoningComponents
from prototorch.core.initializers import RandomReasoningsInitializer
from prototorch.core.losses import MarginLoss
from prototorch.core.similarities import euclidean_similarity
from prototorch.nn.wrappers import LambdaLayer
from ..core.competitions import CBCC
from ..core.components import ReasoningComponents
from ..core.initializers import RandomReasoningsInitializer
from ..core.losses import MarginLoss
from ..core.similarities import euclidean_similarity
from ..nn.wrappers import LambdaLayer
from .abstract import ImagePrototypesMixin
from .glvq import SiameseGLVQ
from .mixin import ImagePrototypesMixin
class CBC(SiameseGLVQ):

View File

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Callable
import torch
from prototorch.core.competitions import WTAC
from prototorch.core.components import LabeledComponents
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import AbstractComponentsInitializer, LabelsInitializer
from prototorch.core.losses import GLVQLoss
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.nn.wrappers import LambdaLayer
@dataclass
class GLVQhparams:
distribution: dict
component_initializer: AbstractComponentsInitializer
distance_fn: Callable = euclidean_distance
lr: float = 0.01
margin: float = 0.0
# TODO: make nicer
transfer_fn: str = "identity"
transfer_beta: float = 10.0
optimizer: torch.optim.Optimizer = torch.optim.Adam
class GLVQ(CLCCScheme):
def __init__(self, hparams: GLVQhparams) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Initializers
def init_components(self, hparams):
# initialize Component Layer
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
def init_comparison(self, hparams):
# initialize Distance Layer
self.comparison_layer = LambdaLayer(hparams.distance_fn)
def init_inference(self, hparams):
self.competition_layer = WTAC()
def init_loss(self, hparams):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
beta=hparams.transfer_beta,
)
# Steps
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(batch_tensor, comp_tensor)
return distances
def inference(self, comparisonmeasures, components):
comp_labels = components[1]
return self.competition_layer(comparisonmeasures, comp_labels)
def loss(self, comparisonmeasures, batch, components):
target = batch[1]
comp_labels = components[1]
return self.loss_layer(comparisonmeasures, target, comp_labels)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr)
# Properties
@property
def prototypes(self):
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
return self.components_layer.labels.detach().cpu()

View File

@@ -0,0 +1,192 @@
"""
CLCC Scheme
CLCC is a LVQ scheme containing 4 steps
- Components
- Latent Space
- Comparison
- Competition
"""
from typing import Dict, Set, Type
import pytorch_lightning as pl
import torch
import torchmetrics
class CLCCScheme(pl.LightningModule):
registered_metrics: Dict[Type[torchmetrics.Metric],
torchmetrics.Metric] = {}
registered_metric_names: Dict[Type[torchmetrics.Metric], Set[str]] = {}
def __init__(self, hparams) -> None:
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# Initialize Model Metrics
self.init_model_metrics()
# internal API, called by models and callbacks
def register_torchmetric(self, name: str, metric: torchmetrics.Metric):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric()
self.registered_metric_names[metric] = {name}
else:
self.registered_metric_names[metric].add(name)
# external API
def get_competion(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competion(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Type hints
# TODO: Docs
def init_components(self, hparams):
...
def init_latent(self, hparams):
...
def init_comparison(self, hparams):
...
def init_competition(self, hparams):
...
def init_loss(self, hparams):
...
def init_inference(self, hparams):
...
def init_model_metrics(self):
self.register_torchmetric('train_accuracy', torchmetrics.Accuracy)
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
"""
The latent step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the componentsset of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparisonmeasures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparisonmeasures, batch, components):
"""
Takes the tensor of competition measures.
Calculates a single loss value
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparisonmeasures, components):
"""
Takes the tensor of competition measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
def update_metrics_step(self, batch):
x, y = batch
preds = self(x)
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance(y, preds)
for name in self.registered_metric_names[metric]:
self.log(name, value)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
value = instance.compute()
for name in self.registered_metric_names[metric]:
self.log(name, value)
# Lightning Hooks
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch)
return self.loss_forward(batch)
def train_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)

View File

@@ -0,0 +1,76 @@
from typing import Optional
import matplotlib.pyplot as plt
import prototorch as pt
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.core.initializers import SMCI, RandomNormalCompInitializer
from prototorch.models.clcc.clcc_glvq import GLVQ, GLVQhparams
from prototorch.models.clcc.clcc_scheme import CLCCScheme
from prototorch.models.vis import Visualize2DVoronoiCallback
# NEW STUFF
# ##############################################################################
# TODO: Metrics
class MetricsTestCallback(pl.Callback):
metric_name = "test_cb_acc"
def setup(self,
trainer: pl.Trainer,
pl_module: CLCCScheme,
stage: Optional[str] = None) -> None:
pl_module.register_torchmetric(self.metric_name, torchmetrics.Accuracy)
def on_epoch_end(self, trainer: pl.Trainer,
pl_module: pl.LightningModule) -> None:
metric = trainer.logged_metrics[self.metric_name]
if metric > 0.95:
trainer.should_stop = True
# TODO: Pruning
# ##############################################################################
if __name__ == "__main__":
# Dataset
train_ds = pt.datasets.Iris(dims=[0, 2])
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=64,
num_workers=8)
components_initializer = SMCI(train_ds)
hparams = GLVQhparams(
distribution=dict(
num_classes=3,
per_class=2,
),
component_initializer=components_initializer,
)
model = GLVQ(hparams)
print(model)
# Callbacks
vis = Visualize2DVoronoiCallback(
data=train_ds,
resolution=500,
)
metrics = MetricsTestCallback()
# Train
trainer = pl.Trainer(
callbacks=[
#vis,
metrics,
],
gpus=1,
max_epochs=100,
weights_summary=None,
log_every_n_steps=1,
)
trainer.fit(model, train_loader)

View File

@@ -1,123 +0,0 @@
"""Prototorch Data Modules
This allows to store the used dataset inside a Lightning Module.
Mainly used for PytorchLightningCLI configurations.
"""
from typing import Any, Optional, Type
import prototorch as pt
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
# MNIST
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
# Download mnist dataset as side-effect, only called on the first cpu
def prepare_data(self):
MNIST("~/datasets", train=True, download=True)
MNIST("~/datasets", train=False, download=True)
# called for every GPU/machine (assigning state is OK)
def setup(self, stage=None):
# Transforms
transform = transforms.Compose([
transforms.ToTensor(),
])
# Split dataset
if stage in (None, "fit"):
mnist_train = MNIST("~/datasets", train=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(
mnist_train,
[55000, 5000],
)
if stage == (None, "test"):
self.mnist_test = MNIST(
"~/datasets",
train=False,
transform=transform,
)
# Dataloaders
def train_dataloader(self):
mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
return mnist_train
def val_dataloader(self):
mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
return mnist_val
def test_dataloader(self):
mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
return mnist_test
# def train_on_mnist(batch_size=256) -> type:
# class DataClass(pl.LightningModule):
# datamodule = MNISTDataModule(batch_size=batch_size)
# def __init__(self, *args, **kwargs):
# prototype_initializer = kwargs.pop(
# "prototype_initializer", pt.components.Zeros((28, 28, 1)))
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# dc: Type[DataClass] = DataClass
# return dc
# ABSTRACT
class GeneralDataModule(pl.LightningDataModule):
def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
super().__init__()
self.train_dataset = dataset
self.batch_size = batch_size
def train_dataloader(self) -> DataLoader:
return DataLoader(self.train_dataset, batch_size=self.batch_size)
# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
# class DataClass(pl.LightningModule):
# datamodule = GeneralDataModule(dataset, batch_size)
# datashape = dataset[0][0].shape
# example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
# def __init__(self, *args: Any, **kwargs: Any) -> None:
# prototype_initializer = kwargs.pop(
# "prototype_initializer",
# pt.components.Zeros(self.datashape),
# )
# super().__init__(*args,
# prototype_initializer=prototype_initializer,
# **kwargs)
# return DataClass
# if __name__ == "__main__":
# from prototorch.models import GLVQ
# demo_dataset = pt.datasets.Iris()
# TrainingClass: Type = train_on_dataset(demo_dataset)
# class DemoGLVQ(TrainingClass, GLVQ):
# """Model Definition."""
# # Hyperparameters
# hparams = dict(
# distribution={
# "num_classes": 3,
# "prototypes_per_class": 4
# },
# lr=0.01,
# )
# initialized = DemoGLVQ(hparams)
# print(initialized)

View File

@@ -5,8 +5,7 @@ Modules not yet available in prototorch go here temporarily.
"""
import torch
from ..core.similarities import gaussian
from prototorch.core.similarities import gaussian
def rank_scaled_gaussian(distances, lambd):

View File

@@ -1,15 +1,16 @@
"""Models based on the GLVQ framework."""
import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from prototorch.core.initializers import EyeTransformInitializer
from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from prototorch.core.transforms import LinearTransform
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.initializers import EyeTransformInitializer
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from ..core.transforms import LinearTransform
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .abstract import SupervisedPrototypeModel
from .mixin import ImagePrototypesMixin
class GLVQ(SupervisedPrototypeModel):
@@ -55,7 +56,7 @@ class GLVQ(SupervisedPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x)
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
loss = self.loss(out, y, plabels)
return out, loss
@@ -112,7 +113,8 @@ class SiameseGLVQ(GLVQ):
proto_opt = self.optimizer(self.proto_layer.parameters(),
lr=self.hparams.proto_lr)
# Only add a backbone optimizer if backbone has trainable parameters
if (bb_params := list(self.backbone.parameters())):
bb_params = list(self.backbone.parameters())
if (bb_params):
bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
optimizers = [proto_opt, bb_opt]
else:
@@ -129,7 +131,7 @@ class SiameseGLVQ(GLVQ):
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
x, protos = (arr.view(arr.size(0), -1) for arr in (x, protos))
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos)
@@ -250,6 +252,12 @@ class GMLVQ(GLVQ):
def omega_matrix(self):
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach() # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
def compute_distances(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega)

View File

@@ -2,10 +2,11 @@
import warnings
from ..core.competitions import KNNC
from ..core.components import LabeledComponents
from ..core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
from ..utils.utils import parse_data_arg
from prototorch.core.competitions import KNNC
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import LiteralCompInitializer, LiteralLabelsInitializer
from prototorch.utils.utils import parse_data_arg
from .abstract import SupervisedPrototypeModel

View File

@@ -1,18 +1,17 @@
"""LVQ models that are optimized using non-gradient methods."""
from ..core.losses import _get_dp_dm
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin
from prototorch.core.losses import _get_dp_dm
from prototorch.nn.activations import get_activation
from prototorch.nn.wrappers import LambdaLayer
from .glvq import GLVQ
from .mixin import NonGradientMixin
class LVQ1(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
protos, plables = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
# TODO Vectorized implementation
@@ -30,8 +29,8 @@ class LVQ1(NonGradientMixin, GLVQ):
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
print(f"{dis=}")
print(f"{y=}")
print(f"dis={dis}")
print(f"y={y}")
# Logging
self.log_acc(dis, y, tag="train_acc")
@@ -41,8 +40,7 @@ class LVQ1(NonGradientMixin, GLVQ):
class LVQ21(NonGradientMixin, GLVQ):
"""Learning Vector Quantization 2.1."""
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)
@@ -99,8 +97,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
return lower_bound
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)

View File

@@ -0,0 +1,27 @@
class ProtoTorchMixin:
"""All mixins are ProtoTorchMixins."""
pass
class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.automatic_optimization = False
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes."""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Constrain the components to the range [0, 1] by clamping after updates."""
self.proto_layer.components.data.clamp_(0.0, 1.0)
def get_prototype_grid(self, num_columns=2, return_channels_last=True):
from torchvision.utils import make_grid
grid = make_grid(self.components, nrow=num_columns)
if return_channels_last:
grid = grid.permute((1, 2, 0))
return grid.cpu()

View File

@@ -1,10 +1,10 @@
"""Probabilistic GLVQ methods"""
import torch
from prototorch.core.losses import nllr_loss, rslvq_loss
from prototorch.core.pooling import stratified_min_pooling, stratified_sum_pooling
from prototorch.nn.wrappers import LambdaLayer, LossLayer
from ..core.losses import nllr_loss, rslvq_loss
from ..core.pooling import stratified_min_pooling, stratified_sum_pooling
from ..nn.wrappers import LambdaLayer, LossLayer
from .extras import GaussianPrior, RankScaledGaussianPrior
from .glvq import GLVQ, SiameseGMLVQ
@@ -20,7 +20,7 @@ class CELVQ(GLVQ):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x) # [None, num_protos]
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning
batch_loss = self.loss(probs, y.long())
@@ -54,7 +54,7 @@ class ProbabilisticLVQ(GLVQ):
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.forward(x)
plabels = self.proto_layer.labels
_, plabels = self.proto_layer()
batch_loss = self.loss(out, y, plabels)
loss = batch_loss.sum()
return loss

View File

@@ -2,14 +2,15 @@
import numpy as np
import torch
from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy
from prototorch.nn.wrappers import LambdaLayer
from ..core.competitions import wtac
from ..core.distances import squared_euclidean_distance
from ..core.losses import NeuralGasEnergy
from ..nn.wrappers import LambdaLayer
from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .abstract import UnsupervisedPrototypeModel
from .callbacks import GNGCallback
from .extras import ConnectionTopology
from .mixin import NonGradientMixin
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
@@ -53,7 +54,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
grid = self._grid.view(-1, 2)
gd = squared_euclidean_distance(wp, grid)
nh = torch.exp(-gd / self._sigma**2)
protos = self.proto_layer.components
protos = self.proto_layer()
diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)

View File

@@ -5,15 +5,18 @@ import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from prototorch.utils.utils import generate_mesh, mesh2d
from torch.utils.data import DataLoader, Dataset
from ..utils.utils import mesh2d
COLOR_UNLABELED = 'w'
class Vis2DAbstract(pl.Callback):
def __init__(self,
data,
title="Prototype Visualization",
title=None,
x_label=None,
y_label=None,
cmap="viridis",
border=0.1,
resolution=100,
@@ -45,6 +48,8 @@ class Vis2DAbstract(pl.Callback):
self.y_train = y
self.title = title
self.x_label = x_label
self.y_label = y_label
self.fig = plt.figure(self.title)
self.cmap = cmap
self.border = border
@@ -57,20 +62,19 @@ class Vis2DAbstract(pl.Callback):
self.pause_time = pause_time
self.block = block
def precheck(self, trainer):
if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1:
return False
def show_on_current_epoch(self, trainer):
if self.show_last_only and trainer.current_epoch != trainer.max_epochs - 1:
return False
return True
def setup_ax(self, xlabel=None, ylabel=None):
def setup_ax(self):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
if xlabel:
ax.set_xlabel("Data dimension 1")
if ylabel:
ax.set_ylabel("Data dimension 2")
if self.x_label:
ax.set_xlabel(self.x_label)
if self.x_label:
ax.set_ylabel(self.y_label)
if self.axis_off:
ax.axis("off")
return ax
@@ -117,25 +121,64 @@ class Vis2DAbstract(pl.Callback):
plt.close()
class VisGLVQ2D(Vis2DAbstract):
class Visualize2DVoronoiCallback(Vis2DAbstract):
def __init__(self, data, **kwargs):
super().__init__(data, **kwargs)
self.data_min = torch.min(self.x_train, axis=0).values
self.data_max = torch.max(self.x_train, axis=0).values
def current_span(self, proto_values):
proto_min = torch.min(proto_values, axis=0).values
proto_max = torch.max(proto_values, axis=0).values
overall_min = torch.minimum(proto_min, self.data_min)
overall_max = torch.maximum(proto_max, self.data_max)
return overall_min, overall_max
def get_voronoi_diagram(self, min, max, model):
mesh_input, (xx, yy) = generate_mesh(
min,
max,
border=self.border,
resolution=self.resolution,
device=model.device,
)
y_pred = model.predict(mesh_input)
return xx, yy, y_pred.reshape(xx.shape)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
if not self.show_on_current_epoch(trainer):
return True
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# Extract Prototypes
proto_values = pl_module.prototypes
if hasattr(pl_module, "prototype_labels"):
proto_labels = pl_module.prototype_labels
else:
proto_labels = COLOR_UNLABELED
# Calculate Voronoi Diagram
overall_min, overall_max = self.current_span(proto_values)
xx, yy, y_pred = self.get_voronoi_diagram(
overall_min,
overall_max,
pl_module,
)
ax = self.setup_ax()
ax.contourf(
xx.cpu(),
yy.cpu(),
y_pred.cpu(),
cmap=self.cmap,
alpha=0.35,
)
self.plot_data(ax, self.x_train, self.y_train)
self.plot_protos(ax, proto_values, proto_labels)
self.log_and_display(trainer, pl_module)
@@ -146,7 +189,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
if not self.show_on_current_epoch(trainer):
return True
protos = pl_module.prototypes
@@ -178,40 +221,49 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract):
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
if not self.show_on_current_epoch(trainer):
return True
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.components_layer._components
y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
if not self.show_on_current_epoch(trainer):
return True
x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
@@ -251,8 +303,6 @@ class VisImgComp(Vis2DAbstract):
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
# print(f"{data.shape=}")
# print(f"{self.y_train[ind].shape=}")
tb.add_embedding(data.view(len(ind), -1),
label_img=data,
global_step=None,
@@ -284,7 +334,7 @@ class VisImgComp(Vis2DAbstract):
)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
if not self.show_on_current_epoch(trainer):
return True
if self.show:

View File

@@ -18,11 +18,11 @@ PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
with open("README.md") as fh:
long_description = fh.read()
INSTALL_REQUIRES = [
"prototorch>=0.6.0",
"prototorch>=0.7.0",
"pytorch_lightning>=1.3.5",
"torchmetrics",
]
@@ -37,6 +37,7 @@ DOCS = [
"recommonmark",
"sphinx",
"nbsphinx",
"ipykernel",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinxcontrib-bibtex",
@@ -53,7 +54,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="0.2.0",
version="0.3.0",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
@@ -63,7 +64,7 @@ setup(
url=PROJECT_URL,
download_url=DOWNLOAD_URL,
license="MIT",
python_requires=">=3.9",
python_requires=">=3.6",
install_requires=INSTALL_REQUIRES,
extras_require={
"dev": DEV,
@@ -80,6 +81,9 @@ setup(
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.6",
"Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries",

View File

@@ -1,11 +1,27 @@
#! /bin/bash
# Read Flags
gpu=0
while [ -n "$1" ]; do
case "$1" in
--gpu) gpu=1;;
-g) gpu=1;;
*) path=$1;;
esac
shift
done
python --version
echo "Using GPU: " $gpu
# Loop
failed=0
for example in $(find $1 -maxdepth 1 -name "*.py")
for example in $(find $path -maxdepth 1 -name "*.py")
do
echo -n "$x" $example '... '
export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt
export DISPLAY= && python $example --fast_dev_run 1 --gpus $gpu &> run_log.txt
if [[ $? -ne 0 ]]; then
echo "FAILED!!"
cat run_log.txt