Merge pull request #5 from si-cim/dev

Dev
This commit is contained in:
Alexander Engelsberger 2021-05-10 15:35:49 +02:00 committed by GitHub
commit 3d02aef755
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 144 additions and 71 deletions

11
.bumpversion.cfg Normal file
View File

@ -0,0 +1,11 @@
[bumpversion]
current_version = 0.0.0
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize =
{major}.{minor}.{patch}
[bumpversion:file:setup.py]
[bumpversion:file:./prototorch/models/__init__.py]

15
.codacy.yml Normal file
View File

@ -0,0 +1,15 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
pylintpython3:
exclude_paths:
- config/engines.yml
remark-lint:
exclude_paths:
- config/engines.yml
exclude_paths:
- 'tests/**'

2
.codecov.yml Normal file
View File

@ -0,0 +1,2 @@
comment:
require_changes: yes

35
.travis.yml Normal file
View File

@ -0,0 +1,35 @@
dist: bionic
sudo: false
language: python
python: 3.8
cache:
directories:
- "./tests/artifacts"
# - "$HOME/.prototorch/datasets"
install:
- pip install .[all] --progress-bar off
# Generate code coverage report
script:
- coverage run -m pytest
# Push the results to codecov
after_success:
- bash <(curl -s https://codecov.io/bash)
# Publish on PyPI
deploy:
provider: pypi
username: __token__
password:
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
on:
tags: true
skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

View File

@ -55,6 +55,7 @@ To assist in the development process, you may also find it useful to install
## Available models
- Generalized Learning Vector Quantization (GLVQ)
- Generalized Relevance Learning Vector Quantization (GRLVQ)
- Generalized Matrix Learning Vector Quantization (GMLVQ)
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
- Siamese GLVQ

View File

@ -29,17 +29,20 @@ if __name__ == "__main__":
# Initialize the model
model = pt.models.GMLVQ(hparams)
# Model summary
print(model)
# Callbacks
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
# Namespace hook for the visualization to work
model.backbone = model.omega_layer
# Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
# Training loop
trainer.fit(model, train_loader)
# Save the model
torch.save(model, "liramlvq_tecator.pt")
# Load a saved model
saved_model = torch.load("liramlvq_tecator.pt")
# Display the Lambda matrix
saved_model.show_lambda()

View File

@ -5,9 +5,4 @@ from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
from .neural_gas import NeuralGas
from .vis import *
VERSION_FALLBACK = "uninstalled_version"
try:
__version__ = version(__name__.replace(".", "-"))
except PackageNotFoundError:
__version__ = VERSION_FALLBACK
pass
__version__ = "0.0.0"

View File

@ -191,14 +191,17 @@ class GMLVQ(GLVQ):
self.hparams.latent_dim,
bias=False)
# Namespace hook for the visualization callbacks to work
self.backbone = self.omega_layer
@property
def omega_matrix(self):
return self.omega_layer.weight.detach().cpu()
@property
def lambda_matrix(self):
omega = self.omega_layer.weight
lam = omega @ omega.T
omega = self.omega_layer.weight # (latent_dim, input_dim)
lam = omega.T @ omega
return lam.detach().cpu()
def show_lambda(self):
@ -250,6 +253,9 @@ class LVQMLN(GLVQ):
**kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone_module(**backbone_params)
with torch.no_grad():
protos = self.backbone(self.proto_layer()[0])
self.proto_layer.load_state_dict({"_components": protos}, strict=False)
def forward(self, x):
latent_protos, _ = self.proto_layer()

View File

@ -269,6 +269,7 @@ class Vis2DAbstract(pl.Callback):
cmap="viridis",
border=1,
resolution=50,
show_protos=True,
tensorboard=False,
show_last_only=False,
pause_time=0.1,
@ -288,11 +289,17 @@ class Vis2DAbstract(pl.Callback):
self.cmap = cmap
self.border = border
self.resolution = resolution
self.show_protos = show_protos
self.tensorboard = tensorboard
self.show_last_only = show_last_only
self.pause_time = pause_time
self.block = block
def precheck(self, trainer):
if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1:
return
def setup_ax(self, xlabel=None, ylabel=None):
ax = self.fig.gca()
ax.cla()
@ -312,6 +319,28 @@ class Vis2DAbstract(pl.Callback):
mesh_input = np.c_[xx.ravel(), yy.ravel()]
return mesh_input, xx, yy
def plot_data(self, ax, x, y):
ax.scatter(
x[:, 0],
x[:, 1],
c=y,
cmap=self.cmap,
edgecolor="k",
marker="o",
s=30,
)
def plot_protos(self, ax, protos, plabels):
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
tb.add_figure(tag=f"{self.title}",
@ -327,118 +356,95 @@ class Vis2DAbstract(pl.Callback):
else:
plt.show(block=True)
def on_train_end(self, trainer, pl_module):
plt.show()
class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1:
return
self.precheck(trainer)
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# ax.set_xlim(left=x_min + 0, right=x_max - 0)
# ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
self.log_and_display(trainer, pl_module)
class VisSiameseGLVQ2D(Vis2DAbstract):
def __init__(self, *args, map_protos=True, **kwargs):
super().__init__(*args, **kwargs)
self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer)
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
if self.map_protos:
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.setup_ax()
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
else:
mesh_input, xx, yy = self.get_mesh_input(x_train)
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# ax.set_xlim(left=x_min + 0, right=x_max - 0)
# ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer)
x_train, y_train = self.x_train, self.y_train
protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="w",
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# ax.set_xlim(left=x_min + 0, right=x_max - 0)
# ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer)
x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
ax.scatter(
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w")
# Draw connections
for i in range(len(protos)):
for j in range(len(protos)):
for j in range(i, len(protos)):
if cmat[i][j]:
ax.plot(
[protos[i, 0], protos[j, 0]],

View File

@ -21,12 +21,12 @@ with open("README.md", "r") as fh:
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"]
EXAMPLES = ["matplotlib", "scikit-learn"]
TESTS = ["pytest"]
TESTS = ["codecov", "pytest"]
ALL = EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
use_scm_version=True,
version="0.0.0",
descripion=
"Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,
@ -36,7 +36,6 @@ setup(
download_url=DOWNLOAD_URL,
license="MIT",
install_requires=INSTALL_REQUIRES,
setup_requires=["setuptools_scm"],
extras_require={
"examples": EXAMPLES,
"tests": TESTS,