Compare commits
10 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
1498c4bde5 | ||
|
59b8ab6643 | ||
|
2a4f184163 | ||
|
265e74dd31 | ||
|
daad018a78 | ||
|
eab1ec72c2 | ||
|
b38acd58a8 | ||
|
e87563e10d | ||
|
767206f905 | ||
|
3fa6378c4d |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.1.5
|
current_version = 0.1.7
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
23
README.md
23
README.md
@@ -8,24 +8,18 @@ PyTorch-Lightning.
|
|||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
To install this plugin, first install
|
To install this plugin, simply run the following command:
|
||||||
[ProtoTorch](https://github.com/si-cim/prototorch) with:
|
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
git clone https://github.com/si-cim/prototorch.git && cd prototorch
|
pip install prototorch_models
|
||||||
pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
and then install the plugin itself with:
|
|
||||||
|
|
||||||
```sh
|
|
||||||
git clone https://github.com/si-cim/prototorch_models.git && cd prototorch_models
|
|
||||||
pip install -e .
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The plugin should then be available for use in your Python environment as
|
The plugin should then be available for use in your Python environment as
|
||||||
`prototorch.models`.
|
`prototorch.models`.
|
||||||
|
|
||||||
|
*Note: Installing the models plugin should automatically install a suitable
|
||||||
|
version of * [ProtoTorch](https://github.com/si-cim/prototorch).
|
||||||
|
|
||||||
## Development setup
|
## Development setup
|
||||||
|
|
||||||
It is recommended that you use a virtual environment for development. If you do
|
It is recommended that you use a virtual environment for development. If you do
|
||||||
@@ -57,17 +51,20 @@ To assist in the development process, you may also find it useful to install
|
|||||||
|
|
||||||
## Available models
|
## Available models
|
||||||
|
|
||||||
|
- K-Nearest Neighbors (KNN)
|
||||||
|
- Learning Vector Quantization 1 (LVQ1)
|
||||||
- Generalized Learning Vector Quantization (GLVQ)
|
- Generalized Learning Vector Quantization (GLVQ)
|
||||||
- Generalized Relevance Learning Vector Quantization (GRLVQ)
|
- Generalized Relevance Learning Vector Quantization (GRLVQ)
|
||||||
- Generalized Matrix Learning Vector Quantization (GMLVQ)
|
- Generalized Matrix Learning Vector Quantization (GMLVQ)
|
||||||
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
|
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
|
||||||
|
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
|
||||||
- Siamese GLVQ
|
- Siamese GLVQ
|
||||||
- Neural Gas (NG)
|
- Neural Gas (NG)
|
||||||
|
|
||||||
## Work in Progress
|
## Work in Progress
|
||||||
|
|
||||||
- Classification-By-Components Network (CBC)
|
- Classification-By-Components Network (CBC)
|
||||||
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
|
- Learning Vector Quantization 2.1 (LVQ2.1)
|
||||||
|
|
||||||
## Planned models
|
## Planned models
|
||||||
|
|
||||||
@@ -76,8 +73,6 @@ To assist in the development process, you may also find it useful to install
|
|||||||
- Robust Soft Learning Vector Quantization (RSLVQ)
|
- Robust Soft Learning Vector Quantization (RSLVQ)
|
||||||
- Probabilistic Learning Vector Quantization (PLVQ)
|
- Probabilistic Learning Vector Quantization (PLVQ)
|
||||||
- Self-Incremental Learning Vector Quantization (SILVQ)
|
- Self-Incremental Learning Vector Quantization (SILVQ)
|
||||||
- K-Nearest Neighbors (KNN)
|
|
||||||
- Learning Vector Quantization 1 (LVQ1)
|
|
||||||
|
|
||||||
## FAQ
|
## FAQ
|
||||||
|
|
||||||
|
@@ -17,15 +17,16 @@ if __name__ == "__main__":
|
|||||||
batch_size=150)
|
batch_size=150)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
|
nclasses = 3
|
||||||
|
prototypes_per_class = 2
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
nclasses=3,
|
distribution=(nclasses, prototypes_per_class),
|
||||||
prototypes_per_class=2,
|
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GLVQ(hparams)
|
model = pt.models.GLVQ(hparams, optimizer=torch.optim.Adam)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||||
|
@@ -25,10 +25,11 @@ if __name__ == "__main__":
|
|||||||
batch_size=256)
|
batch_size=256)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
|
nclasses = 2
|
||||||
|
prototypes_per_class = 20
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
nclasses=2,
|
distribution=(nclasses, prototypes_per_class),
|
||||||
prototypes_per_class=20,
|
prototype_initializer=pt.components.SSI(train_ds, noise=1e-1),
|
||||||
prototype_initializer=pt.components.SSI(train_ds, noise=1e-7),
|
|
||||||
transfer_function="sigmoid_beta",
|
transfer_function="sigmoid_beta",
|
||||||
transfer_beta=10.0,
|
transfer_beta=10.0,
|
||||||
lr=0.01,
|
lr=0.01,
|
||||||
|
@@ -15,9 +15,10 @@ if __name__ == "__main__":
|
|||||||
num_workers=0,
|
num_workers=0,
|
||||||
batch_size=150)
|
batch_size=150)
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
|
nclasses = 3
|
||||||
|
prototypes_per_class = 1
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
nclasses=3,
|
distribution=(nclasses, prototypes_per_class),
|
||||||
prototypes_per_class=1,
|
|
||||||
input_dim=x_train.shape[1],
|
input_dim=x_train.shape[1],
|
||||||
latent_dim=x_train.shape[1],
|
latent_dim=x_train.shape[1],
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
|
37
examples/knn_iris.py
Normal file
37
examples/knn_iris.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""k-NN example using the Iris dataset."""
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Dataset
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
x_train, y_train = load_iris(return_X_y=True)
|
||||||
|
x_train = x_train[:, [0, 2]]
|
||||||
|
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=150)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(k=20)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.KNN(hparams, data=train_ds)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer(max_epochs=1, callbacks=[vis])
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
# This is only for visualization. k-NN has no training phase.
|
||||||
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
# Recall
|
||||||
|
y_pred = model.predict(torch.tensor(x_train))
|
||||||
|
print(y_pred)
|
@@ -17,9 +17,10 @@ if __name__ == "__main__":
|
|||||||
batch_size=32)
|
batch_size=32)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
|
nclasses = 2
|
||||||
|
prototypes_per_class = 2
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
nclasses=2,
|
distribution=(nclasses, prototypes_per_class),
|
||||||
prototypes_per_class=2,
|
|
||||||
input_dim=100,
|
input_dim=100,
|
||||||
latent_dim=2,
|
latent_dim=2,
|
||||||
prototype_initializer=pt.components.SMI(train_ds),
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
|
@@ -24,9 +24,7 @@ class Backbone(torch.nn.Module):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Dataset
|
# Dataset
|
||||||
from sklearn.datasets import load_iris
|
train_ds = pt.datasets.Iris()
|
||||||
x_train, y_train = load_iris(return_X_y=True)
|
|
||||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
pl.utilities.seed.seed_everything(seed=2)
|
pl.utilities.seed.seed_everything(seed=2)
|
||||||
@@ -38,11 +36,10 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
nclasses=3,
|
distribution=[1, 2, 3],
|
||||||
prototypes_per_class=2,
|
prototype_initializer=pt.components.SMI(train_ds),
|
||||||
prototype_initializer=pt.components.SMI((x_train, y_train)),
|
proto_lr=0.01,
|
||||||
proto_lr=0.001,
|
bb_lr=0.01,
|
||||||
bb_lr=0.001,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
@@ -55,7 +52,7 @@ if __name__ == "__main__":
|
|||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisSiameseGLVQ2D(data=(x_train, y_train), border=0.1)
|
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
|
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
|
||||||
|
@@ -1,8 +1,10 @@
|
|||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from .cbc import CBC
|
from .cbc import CBC
|
||||||
from .glvq import GLVQ, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, SiameseGLVQ
|
from .glvq import (GLVQ, GMLVQ, GRLVQ, LVQ1, LVQ21, LVQMLN, ImageGLVQ,
|
||||||
|
SiameseGLVQ)
|
||||||
|
from .knn import KNN
|
||||||
from .neural_gas import NeuralGas
|
from .neural_gas import NeuralGas
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
__version__ = "0.1.5"
|
__version__ = "0.1.7"
|
||||||
|
@@ -3,9 +3,13 @@ import torch
|
|||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
|
|
||||||
class AbstractLightningModel(pl.LightningModule):
|
class AbstractPrototypeModel(pl.LightningModule):
|
||||||
|
@property
|
||||||
|
def prototypes(self):
|
||||||
|
return self.proto_layer.components.detach().cpu()
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
optimizer = self.optimizer(self.parameters(), lr=self.hparams.lr)
|
||||||
scheduler = ExponentialLR(optimizer,
|
scheduler = ExponentialLR(optimizer,
|
||||||
gamma=0.99,
|
gamma=0.99,
|
||||||
last_epoch=-1,
|
last_epoch=-1,
|
||||||
@@ -15,9 +19,3 @@ class AbstractLightningModel(pl.LightningModule):
|
|||||||
"interval": "step",
|
"interval": "step",
|
||||||
} # called after each training step
|
} # called after each training step
|
||||||
return [optimizer], [sch]
|
return [optimizer], [sch]
|
||||||
|
|
||||||
|
|
||||||
class AbstractPrototypeModel(AbstractLightningModel):
|
|
||||||
@property
|
|
||||||
def prototypes(self):
|
|
||||||
return self.proto_layer.components.detach().cpu()
|
|
||||||
|
@@ -5,7 +5,7 @@ from prototorch.functions.activations import get_activation
|
|||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
||||||
squared_euclidean_distance)
|
squared_euclidean_distance)
|
||||||
from prototorch.functions.losses import glvq_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
|
||||||
from .abstract import AbstractPrototypeModel
|
from .abstract import AbstractPrototypeModel
|
||||||
|
|
||||||
@@ -17,19 +17,22 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("distance", euclidean_distance)
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
self.hparams.setdefault("optimizer", torch.optim.Adam)
|
|
||||||
self.hparams.setdefault("transfer_function", "identity")
|
self.hparams.setdefault("transfer_function", "identity")
|
||||||
self.hparams.setdefault("transfer_beta", 10.0)
|
self.hparams.setdefault("transfer_beta", 10.0)
|
||||||
|
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
|
distribution=self.hparams.distribution,
|
||||||
initializer=self.hparams.prototype_initializer)
|
initializer=self.hparams.prototype_initializer)
|
||||||
|
|
||||||
self.transfer_function = get_activation(self.hparams.transfer_function)
|
self.transfer_function = get_activation(self.hparams.transfer_function)
|
||||||
self.train_acc = torchmetrics.Accuracy()
|
self.train_acc = torchmetrics.Accuracy()
|
||||||
|
|
||||||
|
self.loss = glvq_loss
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
return self.proto_layer.component_labels.detach().cpu()
|
return self.proto_layer.component_labels.detach().cpu()
|
||||||
@@ -44,7 +47,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
x = x.view(x.size(0), -1) # flatten
|
x = x.view(x.size(0), -1) # flatten
|
||||||
dis = self(x)
|
dis = self(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
mu = glvq_loss(dis, y, prototype_labels=plabels)
|
mu = self.loss(dis, y, prototype_labels=plabels)
|
||||||
batch_loss = self.transfer_function(mu,
|
batch_loss = self.transfer_function(mu,
|
||||||
beta=self.hparams.transfer_beta)
|
beta=self.hparams.transfer_beta)
|
||||||
loss = batch_loss.sum(dim=0)
|
loss = batch_loss.sum(dim=0)
|
||||||
@@ -76,6 +79,22 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
return y_pred.numpy()
|
return y_pred.numpy()
|
||||||
|
|
||||||
|
|
||||||
|
class LVQ1(GLVQ):
|
||||||
|
"""Learning Vector Quantization 1."""
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
self.loss = lvq1_loss
|
||||||
|
self.optimizer = torch.optim.SGD
|
||||||
|
|
||||||
|
|
||||||
|
class LVQ21(GLVQ):
|
||||||
|
"""Learning Vector Quantization 2.1."""
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
self.loss = lvq21_loss
|
||||||
|
self.optimizer = torch.optim.SGD
|
||||||
|
|
||||||
|
|
||||||
class ImageGLVQ(GLVQ):
|
class ImageGLVQ(GLVQ):
|
||||||
"""GLVQ for training on image data.
|
"""GLVQ for training on image data.
|
||||||
|
|
||||||
@@ -112,13 +131,13 @@ class SiameseGLVQ(GLVQ):
|
|||||||
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optim = self.hparams.optimizer
|
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||||
proto_opt = optim(self.proto_layer.parameters(),
|
lr=self.hparams.proto_lr)
|
||||||
lr=self.hparams.proto_lr)
|
|
||||||
if list(self.backbone.parameters()):
|
if list(self.backbone.parameters()):
|
||||||
# only add an optimizer is the backbone has trainable parameters
|
# only add an optimizer is the backbone has trainable parameters
|
||||||
# otherwise, the next line fails
|
# otherwise, the next line fails
|
||||||
bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
|
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||||
|
lr=self.hparams.bb_lr)
|
||||||
return proto_opt, bb_opt
|
return proto_opt, bb_opt
|
||||||
else:
|
else:
|
||||||
return proto_opt
|
return proto_opt
|
||||||
|
62
prototorch/models/knn.py
Normal file
62
prototorch/models/knn.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""The popular K-Nearest-Neighbors classification algorithm."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchmetrics
|
||||||
|
from prototorch.components import LabeledComponents
|
||||||
|
from prototorch.components.initializers import parse_init_arg
|
||||||
|
from prototorch.functions.competitions import knnc
|
||||||
|
from prototorch.functions.distances import euclidean_distance
|
||||||
|
|
||||||
|
from .abstract import AbstractPrototypeModel
|
||||||
|
|
||||||
|
|
||||||
|
class KNN(AbstractPrototypeModel):
|
||||||
|
"""K-Nearest-Neighbors classification algorithm."""
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
# Default Values
|
||||||
|
self.hparams.setdefault("k", 1)
|
||||||
|
self.hparams.setdefault("distance", euclidean_distance)
|
||||||
|
|
||||||
|
data = kwargs.get("data")
|
||||||
|
x_train, y_train = parse_init_arg(data)
|
||||||
|
|
||||||
|
self.proto_layer = LabeledComponents(initialized_components=(x_train,
|
||||||
|
y_train))
|
||||||
|
|
||||||
|
self.train_acc = torchmetrics.Accuracy()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prototype_labels(self):
|
||||||
|
return self.proto_layer.component_labels.detach().cpu()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
protos, _ = self.proto_layer()
|
||||||
|
dis = self.hparams.distance(x, protos)
|
||||||
|
return dis
|
||||||
|
|
||||||
|
def predict(self, x):
|
||||||
|
# model.eval() # ?!
|
||||||
|
with torch.no_grad():
|
||||||
|
d = self(x)
|
||||||
|
plabels = self.proto_layer.component_labels
|
||||||
|
y_pred = knnc(d, plabels, k=self.hparams.k)
|
||||||
|
return y_pred.numpy()
|
||||||
|
|
||||||
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def on_train_batch_start(self,
|
||||||
|
train_batch,
|
||||||
|
batch_idx,
|
||||||
|
dataloader_idx=None):
|
||||||
|
warnings.warn("k-NN has no training, skipping!")
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
return None
|
@@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from prototorch.components import Components
|
from prototorch.components import Components
|
||||||
from prototorch.components import initializers as cinit
|
from prototorch.components import initializers as cinit
|
||||||
|
from prototorch.components.initializers import ZerosInitializer
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import NeuralGasEnergy
|
from prototorch.modules.losses import NeuralGasEnergy
|
||||||
|
|
||||||
@@ -41,12 +42,14 @@ class NeuralGas(AbstractPrototypeModel):
|
|||||||
|
|
||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("input_dim", 2)
|
self.hparams.setdefault("input_dim", 2)
|
||||||
self.hparams.setdefault("agelimit", 10)
|
self.hparams.setdefault("agelimit", 10)
|
||||||
self.hparams.setdefault("lm", 1)
|
self.hparams.setdefault("lm", 1)
|
||||||
self.hparams.setdefault("prototype_initializer",
|
self.hparams.setdefault("prototype_initializer",
|
||||||
cinit.ZerosInitializer(self.hparams.input_dim))
|
ZerosInitializer(self.hparams.input_dim))
|
||||||
|
|
||||||
self.proto_layer = Components(
|
self.proto_layer = Components(
|
||||||
self.hparams.num_prototypes,
|
self.hparams.num_prototypes,
|
||||||
|
@@ -298,7 +298,8 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
def precheck(self, trainer):
|
def precheck(self, trainer):
|
||||||
if self.show_last_only:
|
if self.show_last_only:
|
||||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||||
return
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def setup_ax(self, xlabel=None, ylabel=None):
|
def setup_ax(self, xlabel=None, ylabel=None):
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
@@ -362,7 +363,8 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
|
|
||||||
class VisGLVQ2D(Vis2DAbstract):
|
class VisGLVQ2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
@@ -386,7 +388,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
self.map_protos = map_protos
|
self.map_protos = map_protos
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
@@ -411,14 +414,15 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
|
|
||||||
class VisCBC2D(Vis2DAbstract):
|
class VisCBC2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.components
|
protos = pl_module.components
|
||||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||||
ylabel="Data dimension 2")
|
ylabel="Data dimension 2")
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
||||||
@@ -431,7 +435,8 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
|
|
||||||
class VisNG2D(Vis2DAbstract):
|
class VisNG2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
self.precheck(trainer)
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
|
4
setup.py
4
setup.py
@@ -19,7 +19,7 @@ DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
|||||||
with open("README.md", "r") as fh:
|
with open("README.md", "r") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
INSTALL_REQUIRES = ["prototorch", "pytorch_lightning", "torchmetrics"]
|
INSTALL_REQUIRES = ["prototorch>=0.4.4", "pytorch_lightning", "torchmetrics"]
|
||||||
DEV = ["bumpversion"]
|
DEV = ["bumpversion"]
|
||||||
EXAMPLES = ["matplotlib", "scikit-learn"]
|
EXAMPLES = ["matplotlib", "scikit-learn"]
|
||||||
TESTS = ["codecov", "pytest"]
|
TESTS = ["codecov", "pytest"]
|
||||||
@@ -27,7 +27,7 @@ ALL = DEV + EXAMPLES + TESTS
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name=safe_name("prototorch_" + PLUGIN_NAME),
|
name=safe_name("prototorch_" + PLUGIN_NAME),
|
||||||
version="0.1.5",
|
version="0.1.7",
|
||||||
description="Pre-packaged prototype-based "
|
description="Pre-packaged prototype-based "
|
||||||
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
"machine learning models using ProtoTorch and PyTorch-Lightning.",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
|
Reference in New Issue
Block a user