Compare commits
15 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
adafb49985
|
||
|
78f8b6cc00
|
||
|
c6f718a1d4
|
||
|
1786031b4e
|
||
|
824dfced92
|
||
|
d4bf6dbbe9 | ||
|
c99fdb436c | ||
|
28ac5f5ed9 | ||
|
b7f510a9fe | ||
|
781ef93b06 | ||
|
072e61b3cd | ||
|
71167a8f77 | ||
|
60990f42d2 | ||
|
1e83c439f7 | ||
|
cbbbbeda98 |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.6.0
|
||||
current_version = 0.7.1
|
||||
commit = True
|
||||
tag = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||
@@ -8,6 +8,6 @@ message = build: bump version {current_version} → {new_version}
|
||||
|
||||
[bumpversion:file:pyproject.toml]
|
||||
|
||||
[bumpversion:file:./prototorch/models/__init__.py]
|
||||
[bumpversion:file:./src/prototorch/models/__init__.py]
|
||||
|
||||
[bumpversion:file:./docs/source/conf.py]
|
||||
|
4
.github/workflows/pythonapp.yml
vendored
4
.github/workflows/pythonapp.yml
vendored
@@ -65,9 +65,9 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[all]
|
||||
pip install wheel
|
||||
pip install build
|
||||
- name: Build package
|
||||
run: python setup.py sdist bdist_wheel
|
||||
run: python -m build . -C verbose
|
||||
- name: Publish a Python distribution to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
#
|
||||
release = "0.6.0"
|
||||
release = "0.7.1"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
|
@@ -1,5 +1,4 @@
|
||||
"""CBC example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
|
BIN
glvq_iris.ckpt
BIN
glvq_iris.ckpt
Binary file not shown.
@@ -1,7 +1,7 @@
|
||||
|
||||
[project]
|
||||
name = "prototorch-models"
|
||||
version = "0.6.0"
|
||||
version = "0.7.1"
|
||||
description = "Pre-packaged prototype-based machine learning models using ProtoTorch and PyTorch-Lightning."
|
||||
authors = [
|
||||
{ name = "Jensun Ravichandran", email = "jjensun@gmail.com" },
|
||||
@@ -64,9 +64,6 @@ all = [
|
||||
"ipykernel",
|
||||
]
|
||||
|
||||
[project.entry-points."prototorch.plugins"]
|
||||
models = "prototorch.models"
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
@@ -91,6 +88,3 @@ line_length = 79
|
||||
[tool.mypy]
|
||||
explicit_package_bases = true
|
||||
namespace_packages = true
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["prototorch"]
|
||||
|
@@ -36,4 +36,4 @@ from .unsupervised import (
|
||||
)
|
||||
from .vis import *
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.7.1"
|
@@ -2,7 +2,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
import prototorch
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -228,7 +227,7 @@ class NonGradientMixin(ProtoTorchMixin):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.automatic_optimization = False
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@@ -44,7 +44,7 @@ class CBC(SiameseGLVQ):
|
||||
probs = self.competition_layer(detections, reasonings)
|
||||
return probs
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
def shared_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
y_pred = self(x)
|
||||
num_classes = self.num_classes
|
||||
@@ -52,8 +52,8 @@ class CBC(SiameseGLVQ):
|
||||
loss = self.loss(y_pred, y_true).mean()
|
||||
return y_pred, loss
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||
def training_step(self, batch, batch_idx):
|
||||
y_pred, train_loss = self.shared_step(batch, batch_idx)
|
||||
preds = torch.argmax(y_pred, dim=1)
|
||||
accuracy = torchmetrics.functional.accuracy(
|
||||
preds.int(),
|
@@ -1,13 +1,15 @@
|
||||
"""Models based on the GLVQ framework."""
|
||||
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
from prototorch.core.competitions import wtac
|
||||
from prototorch.core.distances import (
|
||||
ML_omega_distance,
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.core.initializers import EyeLinearTransformInitializer
|
||||
from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
|
||||
from prototorch.core.losses import (
|
||||
GLVQLoss,
|
||||
lvq1_loss,
|
||||
@@ -15,7 +17,7 @@ from prototorch.core.losses import (
|
||||
)
|
||||
from prototorch.core.transforms import LinearTransform
|
||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.nn import Parameter, ParameterList
|
||||
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
from .extras import ltangent_distance, orthogonalization
|
||||
@@ -45,36 +47,38 @@ class GLVQ(SupervisedPrototypeModel):
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.register_buffer(
|
||||
"prototype_win_ratios",
|
||||
torch.zeros(self.num_prototypes, device=self.device))
|
||||
"prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
|
||||
)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
self.initialize_prototype_win_ratios()
|
||||
|
||||
def log_prototype_win_ratios(self, distances):
|
||||
batch_size = len(distances)
|
||||
prototype_wc = torch.zeros(self.num_prototypes,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
wi, wc = torch.unique(distances.min(dim=-1).indices,
|
||||
sorted=True,
|
||||
return_counts=True)
|
||||
prototype_wc = torch.zeros(
|
||||
self.num_prototypes, dtype=torch.long, device=self.device
|
||||
)
|
||||
wi, wc = torch.unique(
|
||||
distances.min(dim=-1).indices, sorted=True, return_counts=True
|
||||
)
|
||||
prototype_wc[wi] = wc
|
||||
prototype_wr = prototype_wc / batch_size
|
||||
self.prototype_win_ratios = torch.vstack([
|
||||
self.prototype_win_ratios,
|
||||
prototype_wr,
|
||||
])
|
||||
self.prototype_win_ratios = torch.vstack(
|
||||
[
|
||||
self.prototype_win_ratios,
|
||||
prototype_wr,
|
||||
]
|
||||
)
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
def shared_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
out = self.compute_distances(x)
|
||||
_, plabels = self.proto_layer()
|
||||
loss = self.loss(out, y, plabels)
|
||||
return out, loss
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||
def training_step(self, batch, batch_idx):
|
||||
out, train_loss = self.shared_step(batch, batch_idx)
|
||||
self.log_prototype_win_ratios(out)
|
||||
self.log("train_loss", train_loss)
|
||||
self.log_acc(out, batch[-1], tag="train_acc")
|
||||
@@ -99,10 +103,6 @@ class GLVQ(SupervisedPrototypeModel):
|
||||
test_loss += batch_loss.item()
|
||||
self.log("test_loss", test_loss)
|
||||
|
||||
# TODO
|
||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
# pass
|
||||
|
||||
|
||||
class SiameseGLVQ(GLVQ):
|
||||
"""GLVQ in a Siamese setting.
|
||||
@@ -113,11 +113,9 @@ class SiameseGLVQ(GLVQ):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hparams,
|
||||
backbone=torch.nn.Identity(),
|
||||
both_path_gradients=False,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
|
||||
):
|
||||
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
self.backbone = backbone
|
||||
@@ -179,6 +177,7 @@ class GRLVQ(SiameseGLVQ):
|
||||
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||
|
||||
"""
|
||||
|
||||
_relevances: torch.Tensor
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
@@ -189,8 +188,7 @@ class GRLVQ(SiameseGLVQ):
|
||||
self.register_parameter("_relevances", Parameter(relevances))
|
||||
|
||||
# Override the backbone
|
||||
self.backbone = LambdaLayer(self._apply_relevances,
|
||||
name="relevance scaling")
|
||||
self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")
|
||||
|
||||
def _apply_relevances(self, x):
|
||||
return x @ torch.diag(self._relevances)
|
||||
@@ -214,8 +212,9 @@ class SiameseGMLVQ(SiameseGLVQ):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Override the backbone
|
||||
omega_initializer = kwargs.get("omega_initializer",
|
||||
EyeLinearTransformInitializer())
|
||||
omega_initializer = kwargs.get(
|
||||
"omega_initializer", EyeLinearTransformInitializer()
|
||||
)
|
||||
self.backbone = LinearTransform(
|
||||
self.hparams["input_dim"],
|
||||
self.hparams["latent_dim"],
|
||||
@@ -233,6 +232,49 @@ class SiameseGMLVQ(SiameseGLVQ):
|
||||
return lam.detach().cpu()
|
||||
|
||||
|
||||
class GMLMLVQ(GLVQ):
|
||||
"""Generalized Multi-Layer Matrix Learning Vector Quantization.
|
||||
Masks are applied to the omega layers to achieve sparsity and constrain
|
||||
learning to certain items of each omega.
|
||||
|
||||
Implemented as a regular GLVQ network that simply uses a different distance
|
||||
function. This makes it easier to implement a localized variant.
|
||||
"""
|
||||
|
||||
# Parameters
|
||||
_omegas: list[torch.Tensor]
|
||||
masks: list[torch.Tensor]
|
||||
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
self._masks = ParameterList(
|
||||
[Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
|
||||
)
|
||||
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
|
||||
|
||||
@property
|
||||
def omega_matrices(self):
|
||||
return [_omega.detach().cpu() for _omega in self._omegas]
|
||||
|
||||
@property
|
||||
def lambda_matrix(self):
|
||||
# TODO update to respective lambda calculation rules.
|
||||
omega = self._omega.detach() # (input_dim, latent_dim)
|
||||
lam = omega @ omega.T
|
||||
return lam.detach().cpu()
|
||||
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos, self._omegas, self._masks)
|
||||
return distances
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"
|
||||
|
||||
|
||||
class GMLVQ(GLVQ):
|
||||
"""Generalized Matrix Learning Vector Quantization.
|
||||
|
||||
@@ -249,10 +291,12 @@ class GMLVQ(GLVQ):
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
# Additional parameters
|
||||
omega_initializer = kwargs.get("omega_initializer",
|
||||
EyeLinearTransformInitializer())
|
||||
omega = omega_initializer.generate(self.hparams["input_dim"],
|
||||
self.hparams["latent_dim"])
|
||||
omega_initializer = kwargs.get(
|
||||
"omega_initializer", EyeLinearTransformInitializer()
|
||||
)
|
||||
omega = omega_initializer.generate(
|
||||
self.hparams["input_dim"], self.hparams["latent_dim"]
|
||||
)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
@property
|
@@ -34,7 +34,7 @@ class KNN(SupervisedPrototypeModel):
|
||||
labels_initializer=LiteralLabelsInitializer(targets))
|
||||
self.competition_layer = KNNC(k=self.hparams.k)
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
return 1 # skip training step
|
||||
|
||||
def on_train_batch_start(self, train_batch, batch_idx):
|
@@ -13,7 +13,7 @@ from .glvq import GLVQ
|
||||
class LVQ1(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 1."""
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
protos, plables = self.proto_layer()
|
||||
x, y = train_batch
|
||||
dis = self.compute_distances(x)
|
||||
@@ -43,7 +43,7 @@ class LVQ1(NonGradientMixin, GLVQ):
|
||||
class LVQ21(NonGradientMixin, GLVQ):
|
||||
"""Learning Vector Quantization 2.1."""
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
protos, plabels = self.proto_layer()
|
||||
|
||||
x, y = train_batch
|
||||
@@ -100,7 +100,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
||||
lower_bound = (gamma * f.log()).sum()
|
||||
return lower_bound
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
protos, plabels = self.proto_layer()
|
||||
|
||||
x, y = train_batch
|
@@ -21,7 +21,7 @@ class CELVQ(GLVQ):
|
||||
# Loss
|
||||
self.loss = torch.nn.CrossEntropyLoss()
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
def shared_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
out = self.compute_distances(x) # [None, num_protos]
|
||||
_, plabels = self.proto_layer()
|
||||
@@ -63,7 +63,7 @@ class ProbabilisticLVQ(GLVQ):
|
||||
prediction[confidence < self.rejection_confidence] = -1
|
||||
return prediction
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
def training_step(self, batch, batch_idx):
|
||||
x, y = batch
|
||||
out = self.forward(x)
|
||||
_, plabels = self.proto_layer()
|
||||
@@ -123,7 +123,7 @@ class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
self.loss = torch.nn.KLDivLoss()
|
||||
|
||||
# FIXME
|
||||
# def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
# def training_step(self, batch, batch_idx):
|
||||
# x, y = batch
|
||||
# y_pred = self(x)
|
||||
# batch_loss = self.loss(y_pred, y)
|
@@ -1,195 +1,193 @@
|
||||
"""prototorch.models test suite."""
|
||||
|
||||
import prototorch as pt
|
||||
import pytest
|
||||
import torch
|
||||
import prototorch.models
|
||||
|
||||
|
||||
def test_glvq_model_build():
|
||||
model = pt.models.GLVQ(
|
||||
model = prototorch.models.GLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_glvq1_model_build():
|
||||
model = pt.models.GLVQ1(
|
||||
model = prototorch.models.GLVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_glvq21_model_build():
|
||||
model = pt.models.GLVQ1(
|
||||
model = prototorch.models.GLVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_gmlvq_model_build():
|
||||
model = pt.models.GMLVQ(
|
||||
model = prototorch.models.GMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 2,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_grlvq_model_build():
|
||||
model = pt.models.GRLVQ(
|
||||
model = prototorch.models.GRLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_gtlvq_model_build():
|
||||
model = pt.models.GTLVQ(
|
||||
model = prototorch.models.GTLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_lgmlvq_model_build():
|
||||
model = pt.models.LGMLVQ(
|
||||
model = prototorch.models.LGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_image_glvq_model_build():
|
||||
model = pt.models.ImageGLVQ(
|
||||
model = prototorch.models.ImageGLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_image_gmlvq_model_build():
|
||||
model = pt.models.ImageGMLVQ(
|
||||
model = prototorch.models.ImageGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 16,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_image_gtlvq_model_build():
|
||||
model = pt.models.ImageGMLVQ(
|
||||
model = prototorch.models.ImageGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 16,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(16),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(16),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_glvq_model_build():
|
||||
model = pt.models.SiameseGLVQ(
|
||||
model = prototorch.models.SiameseGLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_gmlvq_model_build():
|
||||
model = pt.models.SiameseGMLVQ(
|
||||
model = prototorch.models.SiameseGMLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_siamese_gtlvq_model_build():
|
||||
model = pt.models.SiameseGTLVQ(
|
||||
model = prototorch.models.SiameseGTLVQ(
|
||||
{
|
||||
"distribution": (3, 2),
|
||||
"input_dim": 4,
|
||||
"latent_dim": 2,
|
||||
},
|
||||
prototypes_initializer=pt.initializers.RNCI(4),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(4),
|
||||
)
|
||||
|
||||
|
||||
def test_knn_model_build():
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
model = pt.models.KNN(dict(k=3), data=train_ds)
|
||||
train_ds = prototorch.datasets.Iris(dims=[0, 2])
|
||||
model = prototorch.models.KNN(dict(k=3), data=train_ds)
|
||||
|
||||
|
||||
def test_lvq1_model_build():
|
||||
model = pt.models.LVQ1(
|
||||
model = prototorch.models.LVQ1(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_lvq21_model_build():
|
||||
model = pt.models.LVQ21(
|
||||
model = prototorch.models.LVQ21(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_median_lvq_model_build():
|
||||
model = pt.models.MedianLVQ(
|
||||
model = prototorch.models.MedianLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_celvq_model_build():
|
||||
model = pt.models.CELVQ(
|
||||
model = prototorch.models.CELVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_rslvq_model_build():
|
||||
model = pt.models.RSLVQ(
|
||||
model = prototorch.models.RSLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_slvq_model_build():
|
||||
model = pt.models.SLVQ(
|
||||
model = prototorch.models.SLVQ(
|
||||
{"distribution": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_growing_neural_gas_model_build():
|
||||
model = pt.models.GrowingNeuralGas(
|
||||
model = prototorch.models.GrowingNeuralGas(
|
||||
{"num_prototypes": 5},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_kohonen_som_model_build():
|
||||
model = pt.models.KohonenSOM(
|
||||
model = prototorch.models.KohonenSOM(
|
||||
{"shape": (3, 2)},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
||||
|
||||
def test_neural_gas_model_build():
|
||||
model = pt.models.NeuralGas(
|
||||
model = prototorch.models.NeuralGas(
|
||||
{"num_prototypes": 5},
|
||||
prototypes_initializer=pt.initializers.RNCI(2),
|
||||
prototypes_initializer=prototorch.initializers.RNCI(2),
|
||||
)
|
||||
|
Reference in New Issue
Block a user