Compare commits
No commits in common. "main" and "v0.6.0" have entirely different histories.
@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.7.1
|
current_version = 0.6.0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
@ -8,6 +8,6 @@ message = build: bump version {current_version} → {new_version}
|
|||||||
|
|
||||||
[bumpversion:file:pyproject.toml]
|
[bumpversion:file:pyproject.toml]
|
||||||
|
|
||||||
[bumpversion:file:./src/prototorch/models/__init__.py]
|
[bumpversion:file:./prototorch/models/__init__.py]
|
||||||
|
|
||||||
[bumpversion:file:./docs/source/conf.py]
|
[bumpversion:file:./docs/source/conf.py]
|
||||||
|
4
.github/workflows/pythonapp.yml
vendored
4
.github/workflows/pythonapp.yml
vendored
@ -65,9 +65,9 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .[all]
|
pip install .[all]
|
||||||
pip install build
|
pip install wheel
|
||||||
- name: Build package
|
- name: Build package
|
||||||
run: python -m build . -C verbose
|
run: python setup.py sdist bdist_wheel
|
||||||
- name: Publish a Python distribution to PyPI
|
- name: Publish a Python distribution to PyPI
|
||||||
uses: pypa/gh-action-pypi-publish@release/v1
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
with:
|
with:
|
||||||
|
@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.7.1"
|
release = "0.6.0"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""CBC example using the Iris dataset."""
|
"""CBC example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
BIN
glvq_iris.ckpt
Normal file
BIN
glvq_iris.ckpt
Normal file
Binary file not shown.
@ -36,4 +36,4 @@ from .unsupervised import (
|
|||||||
)
|
)
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
__version__ = "0.7.1"
|
__version__ = "0.6.0"
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import prototorch
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -227,7 +228,7 @@ class NonGradientMixin(ProtoTorchMixin):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.automatic_optimization = False
|
self.automatic_optimization = False
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
@ -44,7 +44,7 @@ class CBC(SiameseGLVQ):
|
|||||||
probs = self.competition_layer(detections, reasonings)
|
probs = self.competition_layer(detections, reasonings)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
num_classes = self.num_classes
|
num_classes = self.num_classes
|
||||||
@ -52,8 +52,8 @@ class CBC(SiameseGLVQ):
|
|||||||
loss = self.loss(y_pred, y_true).mean()
|
loss = self.loss(y_pred, y_true).mean()
|
||||||
return y_pred, loss
|
return y_pred, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
y_pred, train_loss = self.shared_step(batch, batch_idx)
|
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
preds = torch.argmax(y_pred, dim=1)
|
preds = torch.argmax(y_pred, dim=1)
|
||||||
accuracy = torchmetrics.functional.accuracy(
|
accuracy = torchmetrics.functional.accuracy(
|
||||||
preds.int(),
|
preds.int(),
|
@ -66,15 +66,15 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
prototype_wr,
|
prototype_wr,
|
||||||
])
|
])
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x)
|
out = self.compute_distances(x)
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
loss = self.loss(out, y, plabels)
|
loss = self.loss(out, y, plabels)
|
||||||
return out, loss
|
return out, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
out, train_loss = self.shared_step(batch, batch_idx)
|
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
self.log_prototype_win_ratios(out)
|
self.log_prototype_win_ratios(out)
|
||||||
self.log("train_loss", train_loss)
|
self.log("train_loss", train_loss)
|
||||||
self.log_acc(out, batch[-1], tag="train_acc")
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
@ -99,6 +99,10 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
test_loss += batch_loss.item()
|
test_loss += batch_loss.item()
|
||||||
self.log("test_loss", test_loss)
|
self.log("test_loss", test_loss)
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||||
|
# pass
|
||||||
|
|
||||||
|
|
||||||
class SiameseGLVQ(GLVQ):
|
class SiameseGLVQ(GLVQ):
|
||||||
"""GLVQ in a Siamese setting.
|
"""GLVQ in a Siamese setting.
|
@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
|
|||||||
labels_initializer=LiteralLabelsInitializer(targets))
|
labels_initializer=LiteralLabelsInitializer(targets))
|
||||||
self.competition_layer = KNNC(k=self.hparams.k)
|
self.competition_layer = KNNC(k=self.hparams.k)
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
return 1 # skip training step
|
return 1 # skip training step
|
||||||
|
|
||||||
def on_train_batch_start(self, train_batch, batch_idx):
|
def on_train_batch_start(self, train_batch, batch_idx):
|
@ -13,7 +13,7 @@ from .glvq import GLVQ
|
|||||||
class LVQ1(NonGradientMixin, GLVQ):
|
class LVQ1(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 1."""
|
"""Learning Vector Quantization 1."""
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
protos, plables = self.proto_layer()
|
protos, plables = self.proto_layer()
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
@ -43,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
|
|||||||
class LVQ21(NonGradientMixin, GLVQ):
|
class LVQ21(NonGradientMixin, GLVQ):
|
||||||
"""Learning Vector Quantization 2.1."""
|
"""Learning Vector Quantization 2.1."""
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
protos, plabels = self.proto_layer()
|
protos, plabels = self.proto_layer()
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
@ -100,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
|||||||
lower_bound = (gamma * f.log()).sum()
|
lower_bound = (gamma * f.log()).sum()
|
||||||
return lower_bound
|
return lower_bound
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
protos, plabels = self.proto_layer()
|
protos, plabels = self.proto_layer()
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
@ -21,7 +21,7 @@ class CELVQ(GLVQ):
|
|||||||
# Loss
|
# Loss
|
||||||
self.loss = torch.nn.CrossEntropyLoss()
|
self.loss = torch.nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x) # [None, num_protos]
|
out = self.compute_distances(x) # [None, num_protos]
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
prediction[confidence < self.rejection_confidence] = -1
|
prediction[confidence < self.rejection_confidence] = -1
|
||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.forward(x)
|
out = self.forward(x)
|
||||||
_, plabels = self.proto_layer()
|
_, plabels = self.proto_layer()
|
||||||
@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
|||||||
self.loss = torch.nn.KLDivLoss()
|
self.loss = torch.nn.KLDivLoss()
|
||||||
|
|
||||||
# FIXME
|
# FIXME
|
||||||
# def training_step(self, batch, batch_idx):
|
# def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
# x, y = batch
|
# x, y = batch
|
||||||
# y_pred = self(x)
|
# y_pred = self(x)
|
||||||
# batch_loss = self.loss(y_pred, y)
|
# batch_loss = self.loss(y_pred, y)
|
@ -1,7 +1,7 @@
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "prototorch-models"
|
name = "prototorch-models"
|
||||||
version = "0.7.1"
|
version = "0.6.0"
|
||||||
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
|
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
|
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
|
||||||
@ -64,6 +64,9 @@ all = [
|
|||||||
"ipykernel",
|
"ipykernel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.entry-points."prototorch.plugins"]
|
||||||
|
models = "prototorch.models"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=61", "wheel"]
|
requires = ["setuptools>=61", "wheel"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
@ -88,3 +91,6 @@ line_length = 79
|
|||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
explicit_package_bases = true
|
explicit_package_bases = true
|
||||||
namespace_packages = true
|
namespace_packages = true
|
||||||
|
|
||||||
|
[tool.setuptools]
|
||||||
|
py-modules = ["prototorch"]
|
||||||
|
@ -1,193 +1,195 @@
|
|||||||
"""prototorch.models test suite."""
|
"""prototorch.models test suite."""
|
||||||
|
|
||||||
import prototorch.models
|
import prototorch as pt
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def test_glvq_model_build():
|
def test_glvq_model_build():
|
||||||
model = prototorch.models.GLVQ(
|
model = pt.models.GLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_glvq1_model_build():
|
def test_glvq1_model_build():
|
||||||
model = prototorch.models.GLVQ1(
|
model = pt.models.GLVQ1(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_glvq21_model_build():
|
def test_glvq21_model_build():
|
||||||
model = prototorch.models.GLVQ1(
|
model = pt.models.GLVQ1(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_gmlvq_model_build():
|
def test_gmlvq_model_build():
|
||||||
model = prototorch.models.GMLVQ(
|
model = pt.models.GMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 2,
|
"input_dim": 2,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_grlvq_model_build():
|
def test_grlvq_model_build():
|
||||||
model = prototorch.models.GRLVQ(
|
model = pt.models.GRLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 2,
|
"input_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_gtlvq_model_build():
|
def test_gtlvq_model_build():
|
||||||
model = prototorch.models.GTLVQ(
|
model = pt.models.GTLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_lgmlvq_model_build():
|
def test_lgmlvq_model_build():
|
||||||
model = prototorch.models.LGMLVQ(
|
model = pt.models.LGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_image_glvq_model_build():
|
def test_image_glvq_model_build():
|
||||||
model = prototorch.models.ImageGLVQ(
|
model = pt.models.ImageGLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
prototypes_initializer=pt.initializers.RNCI(16),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_image_gmlvq_model_build():
|
def test_image_gmlvq_model_build():
|
||||||
model = prototorch.models.ImageGMLVQ(
|
model = pt.models.ImageGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 16,
|
"input_dim": 16,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
prototypes_initializer=pt.initializers.RNCI(16),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_image_gtlvq_model_build():
|
def test_image_gtlvq_model_build():
|
||||||
model = prototorch.models.ImageGMLVQ(
|
model = pt.models.ImageGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 16,
|
"input_dim": 16,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
prototypes_initializer=pt.initializers.RNCI(16),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_glvq_model_build():
|
def test_siamese_glvq_model_build():
|
||||||
model = prototorch.models.SiameseGLVQ(
|
model = pt.models.SiameseGLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
prototypes_initializer=pt.initializers.RNCI(4),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_gmlvq_model_build():
|
def test_siamese_gmlvq_model_build():
|
||||||
model = prototorch.models.SiameseGMLVQ(
|
model = pt.models.SiameseGMLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
prototypes_initializer=pt.initializers.RNCI(4),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_siamese_gtlvq_model_build():
|
def test_siamese_gtlvq_model_build():
|
||||||
model = prototorch.models.SiameseGTLVQ(
|
model = pt.models.SiameseGTLVQ(
|
||||||
{
|
{
|
||||||
"distribution": (3, 2),
|
"distribution": (3, 2),
|
||||||
"input_dim": 4,
|
"input_dim": 4,
|
||||||
"latent_dim": 2,
|
"latent_dim": 2,
|
||||||
},
|
},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
prototypes_initializer=pt.initializers.RNCI(4),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_knn_model_build():
|
def test_knn_model_build():
|
||||||
train_ds = prototorch.datasets.Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
model = prototorch.models.KNN(dict(k=3), data=train_ds)
|
model = pt.models.KNN(dict(k=3), data=train_ds)
|
||||||
|
|
||||||
|
|
||||||
def test_lvq1_model_build():
|
def test_lvq1_model_build():
|
||||||
model = prototorch.models.LVQ1(
|
model = pt.models.LVQ1(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_lvq21_model_build():
|
def test_lvq21_model_build():
|
||||||
model = prototorch.models.LVQ21(
|
model = pt.models.LVQ21(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_median_lvq_model_build():
|
def test_median_lvq_model_build():
|
||||||
model = prototorch.models.MedianLVQ(
|
model = pt.models.MedianLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_celvq_model_build():
|
def test_celvq_model_build():
|
||||||
model = prototorch.models.CELVQ(
|
model = pt.models.CELVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_rslvq_model_build():
|
def test_rslvq_model_build():
|
||||||
model = prototorch.models.RSLVQ(
|
model = pt.models.RSLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_slvq_model_build():
|
def test_slvq_model_build():
|
||||||
model = prototorch.models.SLVQ(
|
model = pt.models.SLVQ(
|
||||||
{"distribution": (3, 2)},
|
{"distribution": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_growing_neural_gas_model_build():
|
def test_growing_neural_gas_model_build():
|
||||||
model = prototorch.models.GrowingNeuralGas(
|
model = pt.models.GrowingNeuralGas(
|
||||||
{"num_prototypes": 5},
|
{"num_prototypes": 5},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_kohonen_som_model_build():
|
def test_kohonen_som_model_build():
|
||||||
model = prototorch.models.KohonenSOM(
|
model = pt.models.KohonenSOM(
|
||||||
{"shape": (3, 2)},
|
{"shape": (3, 2)},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_neural_gas_model_build():
|
def test_neural_gas_model_build():
|
||||||
model = prototorch.models.NeuralGas(
|
model = pt.models.NeuralGas(
|
||||||
{"num_prototypes": 5},
|
{"num_prototypes": 5},
|
||||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
prototypes_initializer=pt.initializers.RNCI(2),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user