Automatic Formating.

This commit is contained in:
Alexander Engelsberger 2021-04-23 17:27:47 +02:00
parent db4499a103
commit c4c51a16fe
12 changed files with 404 additions and 159 deletions

View File

@ -4,26 +4,24 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models.cbc import CBC, rescaled_cosine_similarity, euclidean_similarity
from prototorch.models.glvq import GLVQ
from sklearn.datasets import make_circles from sklearn.datasets import make_circles
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
class NumpyDataset(TensorDataset): from prototorch.models.callbacks.visualization import VisPointProtos
def __init__(self, *arrays): from prototorch.models.cbc import CBC, euclidean_similarity
# tensors = [torch.from_numpy(arr) for arr in arrays] from prototorch.models.glvq import GLVQ
tensors = [torch.Tensor(arr) for arr in arrays]
super().__init__(*tensors)
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
def __init__(self, def __init__(
x_train, self,
y_train, x_train,
prototype_model=True, y_train,
title="Prototype Visualization", prototype_model=True,
cmap="viridis"): title="Prototype Visualization",
cmap="viridis",
):
super().__init__() super().__init__()
self.x_train = x_train self.x_train = x_train
self.y_train = y_train self.y_train = y_train
@ -38,20 +36,22 @@ class VisualizationCallback(pl.Callback):
color = pl_module.prototype_labels color = pl_module.prototype_labels
else: else:
protos = pl_module.components protos = pl_module.components
color = 'k' color = "k"
ax = self.fig.gca() ax = self.fig.gca()
ax.cla() ax.cla()
ax.set_title(self.title) ax.set_title(self.title)
ax.set_xlabel("Data dimension 1") ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0], ax.scatter(
protos[:, 1], protos[:, 0],
c=color, protos[:, 1],
cmap=self.cmap, c=color,
edgecolor="k", cmap=self.cmap,
marker="D", edgecolor="k",
s=50) marker="D",
s=50,
)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@ -95,7 +95,7 @@ if __name__ == "__main__":
similarity=euclidean_similarity, similarity=euclidean_similarity,
) )
#model = GLVQ(hparams, data=[x_train, y_train]) model = GLVQ(hparams, data=[x_train, y_train])
# Fix the component locations # Fix the component locations
# model.proto_layer.requires_grad_(False) # model.proto_layer.requires_grad_(False)
@ -107,13 +107,21 @@ if __name__ == "__main__":
print(model) print(model)
# Callbacks # Callbacks
vis = VisualizationCallback(x_train, y_train, prototype_model=False) dvis = VisPointProtos(
data=(x_train, y_train),
save=True,
snap=False,
voronoi=True,
resolution=50,
pause_time=0.1,
make_gif=True,
)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=500, max_epochs=10,
callbacks=[ callbacks=[
vis, dvis,
], ],
) )

View File

@ -4,16 +4,11 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models.cbc import CBC
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
class NumpyDataset(TensorDataset): from prototorch.models.cbc import CBC
def __init__(self, *arrays):
# tensors = [torch.from_numpy(arr) for arr in arrays]
tensors = [torch.Tensor(arr) for arr in arrays]
super().__init__(*tensors)
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
@ -47,7 +42,8 @@ class VisualizationCallback(pl.Callback):
cmap=self.cmap, cmap=self.cmap,
edgecolor="k", edgecolor="k",
marker="D", marker="D",
s=50) s=50,
)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@ -73,11 +69,13 @@ if __name__ == "__main__":
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters # Hyperparameters
hparams = dict(input_dim=x_train.shape[1], hparams = dict(
nclasses=3, input_dim=x_train.shape[1],
prototypes_per_class=3, nclasses=3,
prototype_initializer="stratified_mean", prototypes_per_class=3,
lr=0.01) prototype_initializer="stratified_mean",
lr=0.01,
)
# Initialize the model # Initialize the model
model = CBC(hparams, data=[x_train, y_train]) model = CBC(hparams, data=[x_train, y_train])

View File

@ -7,12 +7,12 @@ import argparse
import pytorch_lightning as pl import pytorch_lightning as pl
import torchvision import torchvision
from matplotlib import pyplot as plt
from prototorch.models.cbc import ImageCBC, euclidean_similarity, rescaled_cosine_similarity
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from prototorch.models.cbc import CBC, ImageCBC, euclidean_similarity
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2): def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
@ -89,8 +89,8 @@ if __name__ == "__main__":
) )
# Dataloaders # Dataloaders
train_loader = DataLoader(mnist_train, batch_size=1024) train_loader = DataLoader(mnist_train, batch_size=32)
test_loader = DataLoader(mnist_test, batch_size=1024) test_loader = DataLoader(mnist_test, batch_size=32)
# Grab the full dataset to warm-start prototypes # Grab the full dataset to warm-start prototypes
x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train)))) x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
@ -102,12 +102,12 @@ if __name__ == "__main__":
nclasses=10, nclasses=10,
prototypes_per_class=args.ppc, prototypes_per_class=args.ppc,
prototype_initializer="randn", prototype_initializer="randn",
lr=1, lr=0.01,
similarity=euclidean_similarity, similarity=euclidean_similarity,
) )
# Initialize the model # Initialize the model
model = ImageCBC(hparams, data=[x, y]) model = CBC(hparams, data=[x, y])
# Model summary # Model summary
print(model) print(model)

135
examples/cbc_spiral.py Normal file
View File

@ -0,0 +1,135 @@
"""CBC example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
class VisualizationCallback(pl.Callback):
def __init__(
self,
x_train,
y_train,
prototype_model=True,
title="Prototype Visualization",
cmap="viridis",
):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
self.prototype_model = prototype_model
def on_epoch_end(self, trainer, pl_module):
if self.prototype_model:
protos = pl_module.prototypes
color = pl_module.prototype_labels
else:
protos = pl_module.components
color = "k"
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=color,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
def make_spirals(n_samples=500, noise=0.3):
def get_samples(n, delta_t):
points = []
for i in range(n):
r = i / n_samples * 5
t = 1.75 * i / n * 2 * np.pi + delta_t
x = r * np.sin(t) + np.random.rand(1) * noise
y = r * np.cos(t) + np.random.rand(1) * noise
points.append([x, y])
return points
n = n_samples // 2
positive = get_samples(n=n, delta_t=0)
negative = get_samples(n=n, delta_t=np.pi)
x = np.concatenate(
[np.array(positive).reshape(n, -1),
np.array(negative).reshape(n, -1)],
axis=0)
y = np.concatenate([np.zeros(n), np.ones(n)])
return x, y
if __name__ == "__main__":
# Dataset
x_train, y_train = make_spirals(n_samples=1000, noise=0.3)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=2,
prototypes_per_class=40,
prototype_initializer="stratified_random",
lr=0.05,
)
# Initialize the model
model_class = CBC
model = model_class(hparams, data=[x_train, y_train])
# Pure-positive reasonings
new_reasoning = torch.zeros_like(
model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(model.proto_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0
model.reasoning_layer.reasoning_probabilities.data = new_reasoning
# Model summary
print(model)
# Callbacks
vis = VisualizationCallback(x_train,
y_train,
prototype_model=hasattr(model, "prototypes"))
# Setup trainer
trainer = pl.Trainer(
max_epochs=500,
callbacks=[
vis,
],
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -0,0 +1,142 @@
"""CBC example using the Iris dataset."""
import numpy as np
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
from prototorch.models.cbc import CBC
from prototorch.models.glvq import GLVQ
class VisualizationCallback(pl.Callback):
def __init__(
self,
x_train,
y_train,
prototype_model=True,
title="Prototype Visualization",
cmap="viridis",
):
super().__init__()
self.x_train = x_train
self.y_train = y_train
self.title = title
self.fig = plt.figure(self.title)
self.cmap = cmap
self.prototype_model = prototype_model
def on_epoch_end(self, trainer, pl_module):
if self.prototype_model:
protos = pl_module.prototypes
color = pl_module.prototype_labels
else:
protos = pl_module.components
color = "k"
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=color,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
np.arange(y_min, y_max, 1 / 50))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1)
def make_spirals(n_samples=500, noise=0.3):
def get_samples(n, delta_t):
points = []
for i in range(n):
r = i / n_samples * 5
t = 1.75 * i / n * 2 * np.pi + delta_t
x = r * np.sin(t) + np.random.rand(1) * noise
y = r * np.cos(t) + np.random.rand(1) * noise
points.append([x, y])
return points
n = n_samples // 2
positive = get_samples(n=n, delta_t=0)
negative = get_samples(n=n, delta_t=np.pi)
x = np.concatenate(
[np.array(positive).reshape(n, -1),
np.array(negative).reshape(n, -1)],
axis=0)
y = np.concatenate([np.zeros(n), np.ones(n)])
return x, y
def train(model, x_train, y_train, train_loader, epochs=100):
# Callbacks
vis = VisualizationCallback(x_train,
y_train,
prototype_model=hasattr(model, "prototypes"))
# Setup trainer
trainer = pl.Trainer(
max_epochs=epochs,
callbacks=[
vis,
],
)
# Training loop
trainer.fit(model, train_loader)
if __name__ == "__main__":
# Dataset
x_train, y_train = make_spirals(n_samples=1000, noise=0.3)
train_ds = NumpyDataset(x_train, y_train)
# Dataloaders
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=2,
prototypes_per_class=40,
prototype_initializer="stratified_random",
lr=0.05,
)
# Initialize the model
glvq_model = GLVQ(hparams, data=[x_train, y_train])
cbc_model = CBC(hparams, data=[x_train, y_train])
# Train GLVQ
train(glvq_model, x_train, y_train, train_loader, epochs=10)
# Transfer Prototypes
cbc_model.proto_layer.load_state_dict(glvq_model.proto_layer.state_dict())
# Pure-positive reasonings
new_reasoning = torch.zeros_like(
cbc_model.reasoning_layer.reasoning_probabilities)
for i, label in enumerate(cbc_model.proto_layer.prototype_labels):
new_reasoning[0][0][i][int(label)] = 1.0
new_reasoning[1][0][i][1 - int(label)] = 1.0
cbc_model.reasoning_layer.reasoning_probabilities.data = new_reasoning
# Train CBC
train(cbc_model, x_train, y_train, train_loader, epochs=50)

View File

@ -6,15 +6,11 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
class NumpyDataset(TensorDataset): from prototorch.models.glvq import GLVQ
def __init__(self, *arrays):
tensors = [torch.from_numpy(arr) for arr in arrays]
super().__init__(*tensors)
class GLVQIris(GLVQ): class GLVQIris(GLVQ):
@ -56,13 +52,15 @@ class VisualizationCallback(pl.Callback):
ax.set_xlabel("Data dimension 1") ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0], ax.scatter(
protos[:, 1], protos[:, 0],
c=plabels, protos[:, 1],
cmap=self.cmap, c=plabels,
edgecolor="k", cmap=self.cmap,
marker="D", edgecolor="k",
s=50) marker="D",
s=50,
)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@ -105,8 +103,8 @@ if __name__ == "__main__":
parser, parser,
max_epochs=10, max_epochs=10,
callbacks=[ callbacks=[
vis, # comment this line out to disable the visualization vis,
], ], # comment this line out to disable the visualization
) )
# trainer.tune(model) # trainer.tune(model)

View File

@ -4,15 +4,11 @@ import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.models.glvq import GLVQ
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, TensorDataset from torch.utils.data import DataLoader
from prototorch.datasets.abstract import NumpyDataset
class NumpyDataset(TensorDataset): from prototorch.models.glvq import GLVQ
def __init__(self, *arrays):
tensors = [torch.from_numpy(arr) for arr in arrays]
super().__init__(*tensors)
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
@ -37,13 +33,15 @@ class VisualizationCallback(pl.Callback):
ax.set_xlabel("Data dimension 1") ax.set_xlabel("Data dimension 1")
ax.set_ylabel("Data dimension 2") ax.set_ylabel("Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(protos[:, 0], ax.scatter(
protos[:, 1], protos[:, 0],
c=plabels, protos[:, 1],
cmap=self.cmap, c=plabels,
edgecolor="k", cmap=self.cmap,
marker="D", edgecolor="k",
s=50) marker="D",
s=50,
)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1 x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1 y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
@ -69,11 +67,13 @@ if __name__ == "__main__":
train_loader = DataLoader(train_ds, num_workers=0, batch_size=150) train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
# Hyperparameters # Hyperparameters
hparams = dict(input_dim=x_train.shape[1], hparams = dict(
nclasses=3, input_dim=x_train.shape[1],
prototypes_per_class=3, nclasses=3,
prototype_initializer="stratified_mean", prototypes_per_class=3,
lr=0.1) prototype_initializer="stratified_mean",
lr=0.1,
)
# Initialize the model # Initialize the model
model = GLVQ(hparams, data=[x_train, y_train]) model = GLVQ(hparams, data=[x_train, y_train])

View File

@ -11,13 +11,12 @@ import argparse
import pytorch_lightning as pl import pytorch_lightning as pl
import torchvision import torchvision
from matplotlib import pyplot as plt
from prototorch.functions.initializers import stratified_mean
from prototorch.models.glvq import ImageGLVQ
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from prototorch.models.glvq import ImageGLVQ
class VisualizationCallback(pl.Callback): class VisualizationCallback(pl.Callback):
def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2): def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
@ -31,10 +30,12 @@ class VisualizationCallback(pl.Callback):
grid = torchvision.utils.make_grid(protos_img, nrow=self.nrow) grid = torchvision.utils.make_grid(protos_img, nrow=self.nrow)
# grid = grid.permute((1, 2, 0)) # grid = grid.permute((1, 2, 0))
tb = pl_module.logger.experiment tb = pl_module.logger.experiment
tb.add_image(tag="MNIST Prototypes", tb.add_image(
img_tensor=grid, tag="MNIST Prototypes",
global_step=trainer.current_epoch, img_tensor=grid,
dataformats="CHW") global_step=trainer.current_epoch,
dataformats="CHW",
)
if __name__ == "__main__": if __name__ == "__main__":
@ -91,11 +92,13 @@ if __name__ == "__main__":
x = x.view(len(mnist_train), -1) x = x.view(len(mnist_train), -1)
# Initialize the model # Initialize the model
model = ImageGLVQ(input_dim=28 * 28, model = ImageGLVQ(
nclasses=10, input_dim=28 * 28,
prototypes_per_class=args.ppc, nclasses=10,
prototype_initializer="stratified_mean", prototypes_per_class=args.ppc,
data=[x, y]) prototype_initializer="stratified_mean",
data=[x, y],
)
# Model summary # Model summary
print(model) print(model)

View File

@ -1,8 +1,8 @@
from importlib.metadata import version, PackageNotFoundError from importlib.metadata import PackageNotFoundError, version
VERSION_FALLBACK = "uninstalled_version" VERSION_FALLBACK = "uninstalled_version"
try: try:
__version__ = version(__name__.replace(".", "-")) __version__ = version(__name__.replace(".", "-"))
except PackageNotFoundError: except PackageNotFoundError:
__version__ = VERSION_FALLBACK __version__ = VERSION_FALLBACK
pass pass

View File

@ -1,13 +1,9 @@
import argparse
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity from prototorch.functions.similarities import cosine_similarity
from prototorch.functions.initializers import get_initializer
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
@ -64,9 +60,6 @@ class ReasoningLayer(torch.nn.Module):
super().__init__() super().__init__()
self.n_replicas = n_replicas self.n_replicas = n_replicas
self.n_classes = n_classes self.n_classes = n_classes
# probabilities_init = torch.zeros(2, self.n_replicas, n_components,
# self.n_classes)
# probabilities_init = torch.zeros(2, n_components, self.n_classes)
probabilities_init = torch.zeros(2, 1, n_components, self.n_classes) probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
probabilities_init.uniform_(0.4, 0.6) probabilities_init.uniform_(0.4, 0.6)
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init) self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@ -75,37 +68,28 @@ class ReasoningLayer(torch.nn.Module):
def reasonings(self): def reasonings(self):
pk = self.reasoning_probabilities[0] pk = self.reasoning_probabilities[0]
nk = (1 - pk) * self.reasoning_probabilities[1] nk = (1 - pk) * self.reasoning_probabilities[1]
ik = (1 - pk - nk) ik = 1 - pk - nk
# pk is of shape (1, n_components, n_classes)
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2) img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1) # (n_components, 1, 3, n_classes) return img.unsqueeze(1)
def forward(self, detections): def forward(self, detections):
pk = self.reasoning_probabilities[0].clamp(0, 1) pk = self.reasoning_probabilities[0].clamp(0, 1)
nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1) nk = (1 - pk) * self.reasoning_probabilities[1].clamp(0, 1)
epsilon = torch.finfo(pk.dtype).eps epsilon = torch.finfo(pk.dtype).eps
# print(f"{detections.shape=}")
# print(f"{pk.shape=}")
# print(f"{detections.min()=}")
# print(f"{detections.max()=}")
numerator = (detections @ (pk - nk)) + nk.sum(1) numerator = (detections @ (pk - nk)) + nk.sum(1)
# probs = numerator / (pk + nk).sum(1).clamp(min=epsilon)
probs = numerator / (pk + nk).sum(1) probs = numerator / (pk + nk).sum(1)
# probs = probs.squeeze(0)
probs = probs.squeeze(0) probs = probs.squeeze(0)
return probs return probs
class CBC(pl.LightningModule): class CBC(pl.LightningModule):
"""Classification-By-Components.""" """Classification-By-Components."""
def __init__( def __init__(self,
self, hparams,
hparams, margin=0.1,
margin=0.1, backbone_class=torch.nn.Identity,
backbone_class=torch.nn.Identity, similarity=euclidean_similarity,
# similarity=rescaled_cosine_similarity, **kwargs):
similarity=euclidean_similarity,
**kwargs):
super().__init__() super().__init__()
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
self.margin = margin self.margin = margin
@ -142,15 +126,11 @@ class CBC(pl.LightningModule):
def forward(self, x): def forward(self, x):
self.sync_backbones() self.sync_backbones()
protos = self.proto_layer.prototypes protos, _ = self.proto_layer()
# protos, _ = self.proto_layer()
latent_x = self.backbone(x) latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos) latent_protos = self.backbone_dependent(protos)
# print(f"{latent_x.dtype=}")
# print(f"{latent_protos.dtype=}")
detections = self.similarity(latent_x, latent_protos) detections = self.similarity(latent_x, latent_protos)
probs = self.reasoning_layer(detections) probs = self.reasoning_layer(detections)
return probs return probs
@ -159,20 +139,10 @@ class CBC(pl.LightningModule):
x, y = train_batch x, y = train_batch
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
y_pred = self(x) y_pred = self(x)
# print(f"{y_pred.min()=}")
# print(f"{y_pred.max()=}")
nclasses = self.reasoning_layer.n_classes nclasses = self.reasoning_layer.n_classes
# y_true = torch.nn.functional.one_hot(y, num_classes=nclasses)
# y_true = torch.eye(nclasses)[y.long()]
y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses) y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0) loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
self.log("train_loss", loss) self.log("train_loss", loss)
# with torch.no_grad():
# preds = torch.argmax(y_pred, dim=1)
# # self.train_acc.update(preds.int(), y.int())
# self.train_acc(
# preds.int(),
# y.int()) # FloatTensors are assumed to be class probabilities
self.train_acc(y_pred, y_true) self.train_acc(y_pred, y_true)
self.log( self.log(
"acc", "acc",
@ -184,17 +154,8 @@ class CBC(pl.LightningModule):
) )
return loss return loss
#def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
# self.reasoning_layer.reasoning_probabilities.data.clamp_(0., 1.)
# def training_epoch_end(self, outs):
# # Calling `self.train_acc.compute()` is
# # automatically done by setting `on_epoch=True` when logging in `self.training_step(...)`
# self.log("train_acc_epoch", self.train_acc.compute())
def predict(self, x): def predict(self, x):
with torch.no_grad(): with torch.no_grad():
# model.eval() # ?!
y_pred = self(x) y_pred = self(x)
y_pred = torch.argmax(y_pred, dim=1) y_pred = torch.argmax(y_pred, dim=1)
return y_pred.numpy() return y_pred.numpy()
@ -205,5 +166,5 @@ class ImageCBC(CBC):
clamping after updates. clamping after updates.
""" """
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
#super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
self.proto_layer.prototypes.data.clamp_(0.0, 1.0) self.proto_layer.prototypes.data.clamp_(0.0, 1.0)

View File

@ -1,11 +1,9 @@
import argparse
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.functions.initializers import get_initializer
from prototorch.functions.losses import glvq_loss from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D from prototorch.modules.prototypes import Prototypes1D
@ -54,12 +52,14 @@ class GLVQ(pl.LightningModule):
self.train_acc( self.train_acc(
preds.int(), preds.int(),
y.int()) # FloatTensors are assumed to be class probabilities y.int()) # FloatTensors are assumed to be class probabilities
self.log("acc", self.log(
self.train_acc, "acc",
on_step=False, self.train_acc,
on_epoch=True, on_step=False,
prog_bar=True, on_epoch=True,
logger=True) prog_bar=True,
logger=True,
)
return loss return loss
# def training_epoch_end(self, outs): # def training_epoch_end(self, outs):
@ -81,4 +81,4 @@ class ImageGLVQ(GLVQ):
clamping after updates. clamping after updates.
""" """
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.proto_layer.prototypes.data.clamp_(0., 1.) self.proto_layer.prototypes.data.clamp_(0.0, 1.0)

View File

@ -9,8 +9,7 @@
ProtoTorch models Plugin Package ProtoTorch models Plugin Package
""" """
from pkg_resources import safe_name from pkg_resources import safe_name
from setuptools import setup from setuptools import find_namespace_packages, setup
from setuptools import find_namespace_packages
PLUGIN_NAME = "models" PLUGIN_NAME = "models"
@ -28,7 +27,8 @@ ALL = EXAMPLES + TESTS
setup( setup(
name=safe_name("prototorch_" + PLUGIN_NAME), name=safe_name("prototorch_" + PLUGIN_NAME),
use_scm_version=True, use_scm_version=True,
descripion="Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning.", descripion=
"Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description, long_description=long_description,
author="Alexander Engelsberger", author="Alexander Engelsberger",
author_email="engelsbe@hs-mittweida.de", author_email="engelsbe@hs-mittweida.de",