From 6ed1b9a8325f2794ee01b5bba485042766e10944 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 20 Jun 2023 15:12:32 +0200 Subject: [PATCH] feat: add gmlvq example it was necessary to update the pre-commit definition for a successfull commit. --- .pre-commit-config.yaml | 12 +++---- examples/gmlvq.py | 76 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 examples/gmlvq.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 94784d7..a892e4e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 + rev: v4.4.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -13,17 +13,17 @@ repos: - id: check-case-conflict - repo: https://github.com/myint/autoflake - rev: v1.4 + rev: v2.1.1 hooks: - id: autoflake - repo: http://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v1.3.0 hooks: - id: mypy files: prototorch @@ -35,14 +35,14 @@ repos: - id: yapf - repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.9.0 + rev: v1.10.0 hooks: - id: python-use-type-annotations - id: python-no-log-warn - id: python-check-blanket-noqa - repo: https://github.com/asottile/pyupgrade - rev: v2.31.0 + rev: v3.7.0 hooks: - id: pyupgrade diff --git a/examples/gmlvq.py b/examples/gmlvq.py new file mode 100644 index 0000000..23b8924 --- /dev/null +++ b/examples/gmlvq.py @@ -0,0 +1,76 @@ +"""ProtoTorch CBC example using 2D Iris data.""" + +import torch + +import prototorch as pt + + +class GMLVQ(torch.nn.Module): + """ + Implementation of Generalized Matrix Learning Vector Quantization. + """ + + def __init__(self, data, **kwargs): + super().__init__(**kwargs) + + self.components_layer = pt.components.LabeledComponents( + distribution=[1, 1, 1], + components_initializer=pt.initializers.SMCI(data, noise=0.1), + ) + + self.backbone = pt.transforms.Omega( + len(data[0][0]), + len(data[0][0]), + pt.initializers.RandomLinearTransformInitializer(), + ) + + def forward(self, data): + """ + Forward function that returns a tuple of dissimilarities and label information. + Feed into GLVQLoss to get a complete GMLVQ model. + """ + components, label = self.components_layer() + + latent_x = self.backbone(data) + latent_components = self.backbone(components) + + distance = pt.distances.squared_euclidean_distance( + latent_x, latent_components) + + return distance, label + + def predict(self, data): + """ + The GMLVQ has a modified prediction step, where a competition layer is applied. + """ + components, label = self.components_layer() + distance = pt.distances.squared_euclidean_distance(data, components) + winning_label = pt.competitions.wtac(distance, label) + return winning_label + + +if __name__ == "__main__": + train_ds = pt.datasets.Iris() + + train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32) + + model = GMLVQ(train_ds) + optimizer = torch.optim.Adam(model.parameters(), lr=0.05) + criterion = pt.losses.GLVQLoss() + + for epoch in range(200): + correct = 0.0 + for x, y in train_loader: + d, labels = model(x) + loss = criterion(d, y, labels).mean(0) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + y_pred = model.predict(x) + correct += (y_pred == y).float().sum(0) + + acc = 100 * correct / len(train_ds) + print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")