Merge pull request #1 from si-cim/dev

Merge dev to main
This commit is contained in:
Jensun Ravichandran 2021-04-29 13:24:43 +02:00 committed by GitHub
commit c01c5d16db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 1684 additions and 72 deletions

View File

@ -17,14 +17,44 @@ pip install -e .
The plugin should then be available for use in your Python environment as
`prototorch.models`.
## Development setup
It is recommended that you use a virtual environment for development. If you do
not use `conda`, the easiest way to work with virtual environments is by using
[virtualenvwrapper](https://virtualenvwrapper.readthedocs.io/en/latest/). Once
you've installed it with `pip install virtualenvwrapper`, you can do the
following:
```sh
export WORKON_HOME=~/pyenvs
mkdir -p $WORKON_HOME
source /usr/local/bin/virtualenvwrapper.sh # might be different
# source ~/.local/bin/virtualenvwrapper.sh
mkvirtualenv pt
workon pt
git clone git@github.com:si-cim/prototorch_models.git
cd prototorch_models
git checkout dev
pip install -e .[all] # \[all\] if you are using zsh
```
To assist in the development process, you may also find it useful to install
`yapf`, `isort` and `autoflake`. You can install them easily with `pip`.
## Available models
- [X] GLVQ
- [ ] GMLVQ
- [ ] Local-Matrix GMLVQ
- [ ] Limited-Rank GMLVQ
- [ ] GTLVQ
- [ ] RSLVQ
- [ ] PLVQ
- [ ] LVQMLN
- [ ] CBC
- GLVQ
- Siamese GLVQ
- Neural Gas
## Work in Progress
- CBC
## Planned models
- GMLVQ
- Local-Matrix GMLVQ
- Limited-Rank GMLVQ
- GTLVQ
- RSLVQ
- PLVQ
- LVQMLN

129
examples/cbc_circle.py Normal file
View File

@ -0,0 +1,129 @@
"""CBC example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import make_circles
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.callbacks.visualization import VisPointProtos
from prototorch.models.cbc import CBC, euclidean_similarity
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback):
def __init__(
self,
x_train,
y_train,
prototype_model=True,
title="Prototype Visualization",
cmap="viridis",
):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
self.prototype_model = prototype_model
def on_epoch_end(self, trainer, pl_module):
if self.prototype_model:
protos = pl_module.prototypes
color = pl_module.prototype_labels
else:
protos = pl_module.components
color = "k"
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=color,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__":
# Dataset
x_train, y_train = make_circles(n_samples=300,
shuffle=True,
noise=0.05,
random_state=None,
factor=0.5)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=len(np.unique(y_train)),
prototypes_per_class=5,
prototype_initializer="randn",
lr=0.01,
)
# Initialize the model
model = CBC(
hparams,
data=[x_train, y_train],
similarity=euclidean_similarity,
)
model = GLVQ(hparams, data=[x_train, y_train])
# Fix the component locations
# model.proto_layer.requires_grad_(False)
# import sys
# sys.exit()
# Model summary
print(model)
# Callbacks
dvis = VisPointProtos(
data=(x_train, y_train),
save=True,
snap=False,
voronoi=True,
resolution=50,
pause_time=0.1,
make_gif=True,
)
# Setup trainer
trainer = pl.Trainer(
max_epochs=10,
callbacks=[
dvis,
],
)
# Training loop
trainer.fit(model, train_loader)

114
examples/cbc_iris.py Normal file
View File

@ -0,0 +1,114 @@
"""CBC example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
# protos = pl_module.prototypes
protos = pl_module.components
# plabels = pl_module.prototype_labels
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
# c=plabels,
c="k",
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=3,
prototype_initializer="stratified_mean",
lr=0.01,
)
# Initialize the model
model = CBC(hparams, data=[x_train, y_train])
# Fix the component locations
# model.proto_layer.requires_grad_(False)
# Pure-positive reasonings
ncomps = 3
nclasses = 3
rmat = torch.stack(
[0.9 * torch.eye(ncomps),
torch.zeros(ncomps, nclasses)], dim=0)
# model.reasoning_layer.load_state_dict({"reasoning_probabilities": rmat},
# strict=True)
print(model.reasoning_layer.reasoning_probabilities)
# import sys
# sys.exit()
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(
max_epochs=100,
callbacks=[
vis,
],
)
# Training loop
trainer.fit(model, train_loader)

128
examples/cbc_mnist.py Normal file
View File

@ -0,0 +1,128 @@
"""CBC example using the MNIST dataset.
This script also shows how to use Tensorboard for visualizing the prototypes.
"""
import argparse
import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from prototorch.models.cbc import CBC, ImageCBC, euclidean_similarity
class VisualizationCallback(pl.Callback):
def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
super().__init__()
self.to_shape = to_shape
self.nrow = nrow
def on_epoch_end(self, trainer, pl_module: ImageCBC):
tb = pl_module.logger.experiment
# components
components = pl_module.components
components_img = components.reshape(self.to_shape)
grid = torchvision.utils.make_grid(components_img, nrow=self.nrow)
tb.add_image(
tag="MNIST Components",
img_tensor=grid,
global_step=trainer.current_epoch,
dataformats="CHW",
)
# Reasonings
reasonings = pl_module.reasonings
tb.add_images(
tag="MNIST Reasoning",
img_tensor=reasonings,
global_step=trainer.current_epoch,
dataformats="NCHW",
)
if __name__ == "__main__":
# Arguments
parser = argparse.ArgumentParser()
parser.add_argument("--epochs",
type=int,
default=10,
help="Epochs to train.")
parser.add_argument("--lr",
type=float,
default=0.001,
help="Learning rate.")
parser.add_argument("--batch_size",
type=int,
default=256,
help="Batch size.")
parser.add_argument("--gpus",
type=int,
default=0,
help="Number of GPUs to use.")
parser.add_argument("--ppc",
type=int,
default=1,
help="Prototypes-Per-Class.")
args = parser.parse_args()
# Dataset
mnist_train = MNIST(
"./datasets",
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
]),
)
mnist_test = MNIST(
"./datasets",
train=False,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
]),
)
# Dataloaders
train_loader = DataLoader(mnist_train, batch_size=32)
test_loader = DataLoader(mnist_test, batch_size=32)
# Grab the full dataset to warm-start prototypes
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
x = x.view(len(mnist_train), -1)
# Hyperparameters
hparams = dict(
input_dim=28 * 28,
nclasses=10,
prototypes_per_class=args.ppc,
prototype_initializer="randn",
lr=0.01,
similarity=euclidean_similarity,
)
# Initialize the model
model = CBC(hparams, data=[x, y])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc)
# Setup trainer
trainer = pl.Trainer(
gpus=args.gpus, # change to use GPUs for training
max_epochs=args.epochs,
callbacks=[vis],
track_grad_norm=2,
# accelerator="ddp_cpu", # DEBUG-ONLY
# num_processes=2, # DEBUG-ONLY
)
# Training loop
trainer.fit(model, train_loader, test_loader)

135
examples/cbc_spiral.py Normal file
View File

@ -0,0 +1,135 @@
"""CBC example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
class VisualizationCallback(pl.Callback):
def __init__(
self,
x_train,
y_train,
prototype_model=True,
title="Prototype Visualization",
cmap="viridis",
):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
self.prototype_model = prototype_model
def on_epoch_end(self, trainer, pl_module):
if self.prototype_model:
protos = pl_module.prototypes
color = pl_module.prototype_labels
else:
protos = pl_module.components
color = "k"
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=color,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
def make_spirals(n_samples=500, noise=0.3):
def get_samples(n, delta_t):
points = []
for i in range(n):
r = i / n_samples * 5
t = 1.75 * i / n * 2 * np.pi + delta_t
x = r * np.sin(t) + np.random.rand(1) * noise
y = r * np.cos(t) + np.random.rand(1) * noise
points.append([x, y])
return points
n = n_samples // 2
positive = get_samples(n=n, delta_t=0)
negative = get_samples(n=n, delta_t=np.pi)
x = np.concatenate(
[np.array(positive).reshape(n, -1),
np.array(negative).reshape(n, -1)],
axis=0)
y = np.concatenate([np.zeros(n), np.ones(n)])
return x, y
if __name__ == "__main__":
# Dataset
x_train, y_train = make_spirals(n_samples=1000, noise=0.3)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=2,
prototypes_per_class=40,
prototype_initializer="stratified_random",
lr=0.05,
)
# Initialize the model
model_class = CBC
model = model_class(hparams, data=[x_train, y_train])
# Pure-positive reasonings
new_reasoning = torch.zeros_like(
model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(model.proto_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0
model.reasoning_layer.reasoning_probabilities.data = new_reasoning
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train,
y_train,
prototype_model=hasattr(model, "prototypes"))
# Setup trainer
trainer = pl.Trainer(
max_epochs=500,
callbacks=[
vis,
],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -0,0 +1,146 @@
"""CBC example using the spirals dataset.
This example shows how to jump start a model by transferring weights from
another more stable model.
"""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback):
def __init__(
self,
x_train,
y_train,
prototype_model=True,
title="Prototype Visualization",
cmap="viridis",
):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
self.prototype_model = prototype_model
def on_epoch_end(self, trainer, pl_module):
if self.prototype_model:
protos = pl_module.prototypes
color = pl_module.prototype_labels
else:
protos = pl_module.components
color = "k"
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=color,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
def make_spirals(n_samples=500, noise=0.3):
def get_samples(n, delta_t):
points = []
for i in range(n):
r = i / n_samples * 5
t = 1.75 * i / n * 2 * np.pi + delta_t
x = r * np.sin(t) + np.random.rand(1) * noise
y = r * np.cos(t) + np.random.rand(1) * noise
points.append([x, y])
return points
n = n_samples // 2
positive = get_samples(n=n, delta_t=0)
negative = get_samples(n=n, delta_t=np.pi)
x = np.concatenate(
[np.array(positive).reshape(n, -1),
np.array(negative).reshape(n, -1)],
axis=0)
y = np.concatenate([np.zeros(n), np.ones(n)])
return x, y
def train(model, x_train, y_train, train_loader, epochs=100):
# Callbacks
vis = VisualizationCallback(x_train,
y_train,
prototype_model=hasattr(model, "prototypes"))
# Setup trainer
trainer = pl.Trainer(
max_epochs=epochs,
callbacks=[
vis,
],
)
# Training loop
trainer.fit(model, train_loader)
if __name__ == "__main__":
# Dataset
x_train, y_train = make_spirals(n_samples=1000, noise=0.3)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=2,
prototypes_per_class=40,
prototype_initializer="stratified_random",
lr=0.05,
)
# Initialize the model
glvq_model = GLVQ(hparams, data=[x_train, y_train])
cbc_model = CBC(hparams, data=[x_train, y_train])
# Train GLVQ
train(glvq_model, x_train, y_train, train_loader, epochs=10)
# Transfer Prototypes
cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict())
# Pure-positive reasonings
new_reasoning = torch.zeros_like(
cbc_model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(cbc_model.proto_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0
new_reasoning[1][0][i][1 - int(label)] = 1.0
cbc_model.reasoning_layer.reasoning_probabilities.data = new_reasoning
# Train CBC
train(cbc_model, x_train, y_train, train_loader, epochs=50)

View File

@ -1,18 +1,33 @@
"""GLVQ example using the Iris dataset."""
import argparse
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import GLVQ
class NumpyDataset(TensorDataset):
def __init__(self, *arrays):
tensors = [torch.from_numpy(arr) for arr in arrays]
super().__init__(*tensors)
class GLVQIris(GLVQ):
@staticmethod
def add_model_specific_args(parent_parser):
parser = argparse.ArgumentParser(parents=[parent_parser],
add_help=False)
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-1)
parser.add_argument("--batch_size", type=int, default=150)
parser.add_argument("--input_dim", type=int, default=2)
parser.add_argument("--nclasses", type=int, default=3)
parser.add_argument("--prototypes_per_class", type=int, default=3)
parser.add_argument("--prototype_initializer",
type=str,
default="stratified_mean")
return parser
class VisualizationCallback(pl.Callback):
@ -37,13 +52,15 @@ class VisualizationCallback(pl.Callback):
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50)
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@ -60,6 +77,10 @@ class VisualizationCallback(pl.Callback):
if __name__ == "__main__":
# For best-practices when using `argparse` with `pytorch_lightning`, see
# https://pytorch-lightning.readthedocs.io/en/stable/common/hyperparameters.html
parser = argparse.ArgumentParser()
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
@ -68,29 +89,43 @@ if __name__ == "__main__":
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Initialize the model
model = GLVQ(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=3,
prototype_initializer="stratified_mean",
data=[x_train, y_train],
lr=0.1,
)
# Model summary
print(model)
# Add model specific args
parser = GLVQIris.add_model_specific_args(parser)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Automatically add trainer-specific-args like `--gpus`, `--num_nodes` etc.
parser = pl.Trainer.add_argparse_args(parser)
# Setup trainer
trainer = pl.Trainer(max_epochs=1000, callbacks=[vis])
trainer = pl.Trainer.from_argparse_args(
parser,
max_epochs=10,
callbacks=[
vis,
], # comment this line out to disable the visualization
)
# trainer.tune(model)
# Initialize the model
args = parser.parse_args()
model = GLVQIris(args, data=[x_train, y_train])
# Model summary
print(model)
# Training loop
trainer.fit(model, train_loader)
# Visualization
protos = model.prototypes
plabels = model.prototype_labels
visualize(x_train, y_train, protos, plabels)
# Save the model manually (use `pl.callbacks.ModelCheckpoint` to automate)
ckpt = "glvq_iris.ckpt"
trainer.save_checkpoint(ckpt)
# Load the checkpoint
new_model = GLVQIris.load_from_checkpoint(checkpoint_path=ckpt)
print(new_model)
# Continue training
trainer.fit(new_model, train_loader) # TODO See why this fails!

92
examples/glvq_iris_v1.py Normal file
View File

@ -0,0 +1,92 @@
"""GLVQ example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=3,
prototype_initializer="stratified_mean",
lr=0.1,
)
# Initialize the model
model = GLVQ(hparams, data=[x_train, y_train])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(max_epochs=50, callbacks=[vis])
# Training loop
trainer.fit(model, train_loader)

View File

@ -1,9 +1,5 @@
"""GLVQ example using the MNIST dataset.
TODO
- Add model serialization/deserialization
- Add evaluation metrics
This script also shows how to use Tensorboard for visualizing the prototypes.
"""
@ -11,13 +7,12 @@ import argparse
import pytorch_lightning as pl
import torchvision
from matplotlib import pyplot as plt
from prototorch.functions.initializers import stratified_mean
from prototorch.models.glvq import ImageGLVQ
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from prototorch.models.glvq import ImageGLVQ
class VisualizationCallback(pl.Callback):
def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
@ -31,10 +26,12 @@ class VisualizationCallback(pl.Callback):
grid = torchvision.utils.make_grid(protos_img, nrow=self.nrow)
# grid = grid.permute((1, 2, 0))
tb = pl_module.logger.experiment
tb.add_image(tag="MNIST Prototypes",
img_tensor=grid,
global_step=trainer.current_epoch,
dataformats="CHW")
tb.add_image(
tag="MNIST Prototypes",
img_tensor=grid,
global_step=trainer.current_epoch,
dataformats="CHW",
)
if __name__ == "__main__":
@ -90,12 +87,18 @@ if __name__ == "__main__":
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
x = x.view(len(mnist_train), -1)
# Hyperparameters
hparams = dict(
input_dim=28 * 28,
nclasses=10,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
lr=args.lr,
)
# Initialize the model
model = ImageGLVQ(input_dim=28 * 28,
nclasses=10,
prototypes_per_class=args.ppc,
prototype_initializer="stratified_mean",
data=[x, y])
model = ImageGLVQ(hparams, data=[x, y])
# Model summary
print(model)

104
examples/ng_iris.py Normal file
View File

@ -0,0 +1,104 @@
"""Neural Gas example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.neural_gas import NeuralGas
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Neural Gas Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module: NeuralGas):
protos = pl_module.proto_layer.prototypes.detach().cpu().numpy()
cmat = pl_module.topology_layer.cmat.cpu().numpy()
# Visualize the data and the prototypes
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(self.x_train[:, 0],
self.x_train[:, 1],
c=self.y_train,
edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
# Draw connections
for i in range(len(protos)):
for j in range(len(protos)):
if cmat[i][j]:
ax.plot(
[protos[i, 0], protos[j, 0]],
[protos[i, 1], protos[j, 1]],
"k-",
)
plt.pause(0.01)
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
scaler = StandardScaler()
scaler.fit(x_train)
x_train = scaler.transform(x_train)
y_single_class = np.zeros_like(y_train)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=1,
prototypes_per_class=30,
prototype_initializer="rand",
lr=0.1,
)
# Initialize the model
model = NeuralGas(hparams, data=[x_train, y_single_class])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(
max_epochs=100,
callbacks=[
vis,
],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -0,0 +1,115 @@
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.glvq import SiameseGLVQ
class VisualizationCallback(pl.Callback):
def __init__(self,
x_train,
y_train,
title="Prototype Visualization",
cmap="viridis"):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
def on_epoch_end(self, trainer, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.axis("off")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 0.2, x[:, 0].max() + 0.2
y_min, y_max = x[:, 1].min() - 0.2, x[:, 1].max() + 0.2
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
tb = pl_module.logger.experiment
tb.add_figure(
tag=f"{self.title}",
figure=self.fig,
global_step=trainer.current_epoch,
close=False,
)
plt.pause(0.1)
class Backbone(torch.nn.Module):
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.latent_size = latent_size
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(self.dense2(self.relu(self.dense1(x))))
if __name__ == "__main__":
# Dataset
x_train, y_train = load_iris(return_X_y=True)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
prototypes_per_class=1,
prototype_initializer="stratified_mean",
lr=0.01,
)
# Initialize the model
model = SiameseGLVQ(hparams,
backbone_module=Backbone,
data=[x_train, y_train])
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train, y_train)
# Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
# Training loop
trainer.fit(model, train_loader)

View File

@ -1,8 +1,8 @@
from importlib.metadata import version, PackageNotFoundError
from importlib.metadata import PackageNotFoundError, version
VERSION_FALLBACK = "uninstalled_version"
try:
__version__ = version(__name__.replace(".", "-"))
except PackageNotFoundError:
__version__ = VERSION_FALLBACK
pass
pass

View File

@ -0,0 +1,260 @@
import os
import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import gif_from_dir, make_directory, prettify_string
class VisWeights(Callback):
"""Abstract weight visualization callback."""
def __init__(
self,
data=None,
ignore_last_output_row=False,
label_map=None,
project_mesh=False,
project_protos=False,
voronoi=False,
axis_off=True,
cmap="viridis",
show=True,
display_logs=True,
display_logs_settings={},
pause_time=0.5,
border=1,
resolution=10,
interval=False,
save=False,
snap=True,
save_dir="./img",
make_gif=False,
make_mp4=False,
verbose=True,
dpi=500,
fps=5,
figsize=(11, 8.5), # standard paper in inches
prefix="",
distance_layer_index=-1,
**kwargs,
):
super().__init__(**kwargs)
self.data = data
self.ignore_last_output_row = ignore_last_output_row
self.label_map = label_map
self.voronoi = voronoi
self.axis_off = True
self.project_mesh = project_mesh
self.project_protos = project_protos
self.cmap = cmap
self.show = show
self.display_logs = display_logs
self.display_logs_settings = display_logs_settings
self.pause_time = pause_time
self.border = border
self.resolution = resolution
self.interval = interval
self.save = save
self.snap = snap
self.save_dir = save_dir
self.make_gif = make_gif
self.make_mp4 = make_mp4
self.verbose = verbose
self.dpi = dpi
self.fps = fps
self.figsize = figsize
self.prefix = prefix
self.distance_layer_index = distance_layer_index
self.title = "Weights Visualization"
make_directory(self.save_dir)
def _skip_epoch(self, epoch):
if self.interval:
if epoch % self.interval != 0:
return True
return False
def _clean_and_setup_ax(self):
ax = self.ax
if not self.snap:
ax.cla()
ax.set_title(self.title)
if self.axis_off:
ax.axis("off")
def _savefig(self, fignum, orientation="horizontal"):
figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png"
figsize = self.figsize
if orientation == "vertical":
figsize = figsize[::-1]
elif orientation == "horizontal":
pass
else:
pass
self.fig.set_size_inches(figsize, forward=False)
self.fig.savefig(figname, dpi=self.dpi)
def _show_and_save(self, epoch):
if self.show:
plt.pause(self.pause_time)
if self.save:
self._savefig(epoch)
if self.snap:
self.camera.snap()
def _display_logs(self, ax, epoch, logs):
if self.display_logs:
settings = dict(
loc="lower right",
# padding between the text and bounding box
pad=0.5,
# padding between the bounding box and the axes
borderpad=1.0,
# https://matplotlib.org/api/text_api.html#matplotlib.text.Text
prop=dict(
fontfamily="monospace",
fontweight="medium",
fontsize=12,
),
)
# Override settings with self.display_logs_settings.
settings = {**settings, **self.display_logs_settings}
log_string = f"""Epoch: {epoch:04d},
val_loss: {logs.get('val_loss', np.nan):.03f},
val_acc: {logs.get('val_acc', np.nan):.03f},
loss: {logs.get('loss', np.nan):.03f},
acc: {logs.get('acc', np.nan):.03f}
"""
log_string = prettify_string(log_string, end="")
# https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText
anchored_text = AnchoredText(log_string, **settings)
self.ax.add_artist(anchored_text)
def on_train_start(self, trainer, pl_module, logs={}):
self.fig = plt.figure(self.title)
self.fig.set_size_inches(self.figsize, forward=False)
self.ax = self.fig.add_subplot(111)
self.camera = Camera(self.fig)
def on_train_end(self, trainer, pl_module, logs={}):
if self.make_gif:
gif_from_dir(directory=self.save_dir,
prefix=self.prefix,
duration=1.0 / self.fps)
if self.snap and self.make_mp4:
animation = self.camera.animate()
vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4")
if self.verbose:
print(f"Saving mp4 under {vid}.")
animation.save(vid, fps=self.fps, dpi=self.dpi)
class VisPointProtos(VisWeights):
"""Visualization of prototypes.
.. TODO::
Still in Progress.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.title = "Point Prototypes Visualization"
self.data_scatter_settings = {
"marker": "o",
"s": 30,
"edgecolor": "k",
"cmap": self.cmap,
}
self.protos_scatter_settings = {
"marker": "D",
"s": 50,
"edgecolor": "k",
"cmap": self.cmap,
}
def on_epoch_start(self, trainer, pl_module, logs={}):
epoch = trainer.current_epoch
if self._skip_epoch(epoch):
return True
self._clean_and_setup_ax()
protos = pl_module.prototypes
labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy()
if self.project_protos:
protos = self.model.projection(protos).numpy()
color_map = color_scheme(n=len(set(labels)),
cmap=self.cmap,
zero_indexed=True)
# TODO Get rid of the assumption y values in [0, num_of_classes]
label_colors = [color_map[l] for l in labels]
if self.data is not None:
x, y = self.data
# TODO Get rid of the assumption y values in [0, num_of_classes]
y_colors = [color_map[l] for l in y]
# x = self.model.projection(x)
if not isinstance(x, np.ndarray):
x = x.numpy()
# Plot data points.
self.ax.scatter(x[:, 0],
x[:, 1],
c=y_colors,
**self.data_scatter_settings)
# Paint decision regions.
if self.voronoi:
border = self.border
resolution = self.resolution
x = np.vstack((x, protos))
x_min, x_max = x[:, 0].min(), x[:, 0].max()
y_min, y_max = x[:, 1].min(), x[:, 1].max()
x_min, x_max = x_min - border, x_max + border
y_min, y_max = y_min - border, y_max + border
try:
xx, yy = np.meshgrid(
np.arange(x_min, x_max, (x_max - x_min) / resolution),
np.arange(y_min, y_max, (x_max - x_min) / resolution),
)
except ValueError as ve:
print(ve)
raise ValueError(f"x_min: {x_min}, x_max: {x_max}. "
f"x_min - x_max is {x_max - x_min}.")
except MemoryError as me:
print(me)
raise ValueError("Too many points. "
"Try reducing the resolution.")
mesh_input = np.c_[xx.ravel(), yy.ravel()]
# Predict mesh labels.
if self.project_mesh:
mesh_input = self.model.projection(mesh_input)
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions.
self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.ax.set_xlim(left=x_min + 0, right=x_max - 0)
self.ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
# Plot prototypes.
self.ax.scatter(protos[:, 0],
protos[:, 1],
c=label_colors,
**self.protos_scatter_settings)
# self._show_and_save(epoch)
def on_epoch_end(self, trainer, pl_module, logs={}):
epoch = trainer.current_epoch
self._display_logs(self.ax, epoch, logs)
self._show_and_save(epoch)

170
prototorch/models/cbc.py Normal file
View File

@ -0,0 +1,170 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from prototorch.modules.prototypes import Prototypes1D
def rescaled_cosine_similarity(x, y):
"""Cosine Similarity rescaled to [0, 1]."""
similarities = cosine_similarity(x, y)
return (similarities + 1.0) / 2.0
def shift_activation(x):
return (x + 1.0) / 2.0
def euclidean_similarity(x, y):
d = euclidean_distance(x, y)
return torch.exp(-d * 3)
class CosineSimilarity(torch.nn.Module):
def __init__(self, activation=shift_activation):
super().__init__()
self.activation = activation
def forward(self, x, y):
epsilon = torch.finfo(x.dtype).eps
normed_x = (x / x.pow(2).sum(dim=tuple(range(
1, x.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
normed_y = (y / y.pow(2).sum(dim=tuple(range(
1, y.ndim)), keepdim=True).clamp(min=epsilon).sqrt()).flatten(
start_dim=1)
# normed_x = (x / torch.linalg.norm(x, dim=1))
diss = torch.inner(normed_x, normed_y)
return self.activation(diss)
class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self,
margin=0.3,
size_average=None,
reduce=None,
reduction="mean"):
super().__init__(size_average, reduce, reduction)
self.margin = margin
def forward(self, input_, target):
dp = torch.sum(target * input_, dim=-1)
dm = torch.max(input_ - target, dim=-1).values
return torch.nn.functional.relu(dm - dp + self.margin)
class ReasoningLayer(torch.nn.Module):
def __init__(self, n_components, n_classes, n_replicas=1):
super().__init__()
self.n_replicas = n_replicas
self.n_classes = n_classes
probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
probabilities_init.uniform_(0.4, 0.6)
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@property
def reasonings(self):
pk = self.reasoning_probabilities[0]
nk = (1 - pk) * self.reasoning_probabilities[1]
ik = 1 - pk - nk
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1)
def forward(self, detections):
pk = self.reasoning_probabilities[0].clamp(0, 1)
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
epsilon = torch.finfo(pk.dtype).eps
numerator = (detections @ (pk - nk)) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
probs = probs.squeeze(0)
return probs
class CBC(pl.LightningModule):
"""Classification-By-Components."""
def __init__(self,
hparams,
margin=0.1,
backbone_class=torch.nn.Identity,
similarity=euclidean_similarity,
**kwargs):
super().__init__()
self.save_hyperparameters(hparams)
self.margin = margin
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
# self.similarity = CosineSimilarity()
self.similarity = similarity
self.backbone = backbone_class()
self.backbone_dependent = backbone_class().requires_grad_(False)
n_components = self.components.shape[0]
self.reasoning_layer = ReasoningLayer(n_components=n_components,
n_classes=self.hparams.nclasses)
self.train_acc = torchmetrics.Accuracy()
@property
def components(self):
return self.proto_layer.prototypes.detach().cpu()
@property
def reasonings(self):
return self.reasoning_layer.reasonings.cpu()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer
def sync_backbones(self):
master_state = self.backbone.state_dict()
self.backbone_dependent.load_state_dict(master_state, strict=True)
def forward(self, x):
self.sync_backbones()
protos, _ = self.proto_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
detections = self.similarity(latent_x, latent_protos)
probs = self.reasoning_layer(detections)
return probs
def training_step(self, train_batch, batch_idx):
x, y = train_batch
x = x.view(x.size(0), -1)
y_pred = self(x)
nclasses = self.reasoning_layer.n_classes
y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
self.log("train_loss", loss)
self.train_acc(y_pred, y_true)
self.log(
"acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
def predict(self, x):
with torch.no_grad():
y_pred = self(x)
y_pred = torch.argmax(y_pred, dim=1)
return y_pred.numpy()
class ImageCBC(CBC):
"""CBC model that constrains the components to the range [0, 1] by
clamping after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
# super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)

View File

@ -1,18 +1,31 @@
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.initializers import get_initializer
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
class GLVQ(pl.LightningModule):
"""Generalized Learning Vector Quantization."""
def __init__(self, lr=1e-3, **kwargs):
def __init__(self, hparams, **kwargs):
super().__init__()
self.lr = lr
self.proto_layer = Prototypes1D(**kwargs)
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("distance", euclidean_distance)
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs)
self.train_acc = torchmetrics.Accuracy()
@property
def prototypes(self):
@ -22,15 +35,15 @@ class GLVQ(pl.LightningModule):
def prototype_labels(self):
return self.proto_layer.prototype_labels.detach().numpy()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer
def forward(self, x):
protos = self.proto_layer.prototypes
dis = euclidean_distance(x, protos)
dis = self.hparams.distance(x, protos)
return dis
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
return optimizer
def training_step(self, train_batch, batch_idx):
x, y = train_batch
x = x.view(x.size(0), -1)
@ -39,9 +52,29 @@ class GLVQ(pl.LightningModule):
mu = glvq_loss(dis, y, prototype_labels=plabels)
loss = mu.sum(dim=0)
self.log("train_loss", loss)
with torch.no_grad():
preds = wtac(dis, plabels)
# self.train_acc.update(preds.int(), y.int())
self.train_acc(
preds.int(),
y.int()) # FloatTensors are assumed to be class probabilities
self.log(
"acc",
self.train_acc,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
# def training_epoch_end(self, outs):
# # Calling `self.train_acc.compute()` is
# # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)`
# self.log("train_acc_epoch", self.train_acc.compute())
def predict(self, x):
# model.eval() # ?!
with torch.no_grad():
d = self(x)
plabels = self.proto_layer.prototype_labels
@ -50,8 +83,52 @@ class GLVQ(pl.LightningModule):
class ImageGLVQ(GLVQ):
"""GLVQ model that constrains the prototypes to the range [0, 1] by
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by
clamping after updates.
"""
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.prototypes.data.clamp_(0., 1.)
self.proto_layer.prototypes.data.clamp_(0.0, 1.0)
class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting.
GLVQ model that applies an arbitrary transformation on the inputs and the
prototypes before computing the distances between them. The weights in the
transformation pipeline are only learned from the inputs.
"""
def __init__(self,
hparams,
backbone_module=torch.nn.Identity,
backbone_params={},
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
self.backbone_dependent = backbone_module(
**backbone_params).requires_grad_(False)
def sync_backbones(self):
master_state = self.backbone.state_dict()
self.backbone_dependent.load_state_dict(master_state, strict=True)
def forward(self, x):
self.sync_backbones()
protos = self.proto_layer.prototypes
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
dis = euclidean_distance(latent_x, latent_protos)
return dis
def predict_latent(self, x):
# model.eval() # ?!
with torch.no_grad():
protos = self.proto_layer.prototypes
latent_protos = self.backbone_dependent(protos)
d = euclidean_distance(x, latent_protos)
plabels = self.proto_layer.prototype_labels
y_pred = wtac(d, plabels)
return y_pred.numpy()

View File

@ -0,0 +1,74 @@
import pytorch_lightning as pl
import torch
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import Prototypes1D
from prototorch.modules.losses import NeuralGasEnergy
class EuclideanDistance(torch.nn.Module):
def forward(self, x, y):
return euclidean_distance(x, y)
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
self.agelimit = agelimit
self.num_prototypes = num_prototypes
self.cmat = torch.zeros((self.num_prototypes, self.num_prototypes))
self.age = torch.zeros_like(self.cmat)
def forward(self, d):
order = torch.argsort(d, dim=1)
for element in order:
i0, i1 = element[0], element[1]
self.cmat[i0][i1] = 1
self.age[i0][i1] = 0
self.age[i0][self.cmat[i0] == 1] += 1
self.cmat[i0][self.age[i0] > self.agelimit] = 0
def extra_repr(self):
return f"agelimit: {self.agelimit}"
class NeuralGas(pl.LightningModule):
def __init__(self, hparams, **kwargs):
super().__init__()
self.save_hyperparameters(hparams)
# Default Values
self.hparams.setdefault("agelimit", 10)
self.hparams.setdefault("lm", 1)
self.hparams.setdefault("prototype_initializer", "zeros")
self.proto_layer = Prototypes1D(
input_dim=self.hparams.input_dim,
nclasses=self.hparams.nclasses,
prototypes_per_class=self.hparams.prototypes_per_class,
prototype_initializer=self.hparams.prototype_initializer,
**kwargs,
)
self.distance_layer = EuclideanDistance()
self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
self.topology_layer = ConnectionTopology(
agelimit=self.hparams.agelimit,
num_prototypes=len(self.proto_layer.prototypes),
)
def training_step(self, train_batch, batch_idx):
x, _ = train_batch
protos, _ = self.proto_layer()
d = self.distance_layer(x, protos)
cost, order = self.energy_layer(d)
self.topology_layer(d)
return cost
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer

View File

@ -9,8 +9,7 @@
ProtoTorch models Plugin Package
"""
from pkg_resources import safe_name
from setuptools import setup
from setuptools import find_namespace_packages
from setuptools import find_namespace_packages, setup
PLUGIN_NAME = "models"
@ -20,7 +19,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning"]
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"]
EXAMPLES = ["matplotlib", "scikit-learn"]
TESTS = ["pytest"]
ALL = EXAMPLES + TESTS
@ -28,7 +27,8 @@ ALL = EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
use_scm_version=True,
descripion="Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning.",
descripion=
"Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
author="Alexander Engelsberger",
author_email="engelsbe@hs-mittweida.de",