feat: add gmlvq example
it was necessary to update the pre-commit definition for a successfull commit.
This commit is contained in:
parent
4a7d4a3d99
commit
6ed1b9a832
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.1.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@ -13,17 +13,17 @@ repos:
|
|||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
|
|
||||||
- repo: https://github.com/myint/autoflake
|
- repo: https://github.com/myint/autoflake
|
||||||
rev: v1.4
|
rev: v2.1.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: autoflake
|
- id: autoflake
|
||||||
|
|
||||||
- repo: http://github.com/PyCQA/isort
|
- repo: http://github.com/PyCQA/isort
|
||||||
rev: 5.10.1
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v0.931
|
rev: v1.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
files: prototorch
|
files: prototorch
|
||||||
@ -35,14 +35,14 @@ repos:
|
|||||||
- id: yapf
|
- id: yapf
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.9.0
|
rev: v1.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-use-type-annotations
|
- id: python-use-type-annotations
|
||||||
- id: python-no-log-warn
|
- id: python-no-log-warn
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v2.31.0
|
rev: v3.7.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
|
||||||
|
76
examples/gmlvq.py
Normal file
76
examples/gmlvq.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
"""ProtoTorch CBC example using 2D Iris data."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
|
||||||
|
class GMLVQ(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of Generalized Matrix Learning Vector Quantization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.components_layer = pt.components.LabeledComponents(
|
||||||
|
distribution=[1, 1, 1],
|
||||||
|
components_initializer=pt.initializers.SMCI(data, noise=0.1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.backbone = pt.transforms.Omega(
|
||||||
|
len(data[0][0]),
|
||||||
|
len(data[0][0]),
|
||||||
|
pt.initializers.RandomLinearTransformInitializer(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
"""
|
||||||
|
Forward function that returns a tuple of dissimilarities and label information.
|
||||||
|
Feed into GLVQLoss to get a complete GMLVQ model.
|
||||||
|
"""
|
||||||
|
components, label = self.components_layer()
|
||||||
|
|
||||||
|
latent_x = self.backbone(data)
|
||||||
|
latent_components = self.backbone(components)
|
||||||
|
|
||||||
|
distance = pt.distances.squared_euclidean_distance(
|
||||||
|
latent_x, latent_components)
|
||||||
|
|
||||||
|
return distance, label
|
||||||
|
|
||||||
|
def predict(self, data):
|
||||||
|
"""
|
||||||
|
The GMLVQ has a modified prediction step, where a competition layer is applied.
|
||||||
|
"""
|
||||||
|
components, label = self.components_layer()
|
||||||
|
distance = pt.distances.squared_euclidean_distance(data, components)
|
||||||
|
winning_label = pt.competitions.wtac(distance, label)
|
||||||
|
return winning_label
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_ds = pt.datasets.Iris()
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||||
|
|
||||||
|
model = GMLVQ(train_ds)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
|
||||||
|
criterion = pt.losses.GLVQLoss()
|
||||||
|
|
||||||
|
for epoch in range(200):
|
||||||
|
correct = 0.0
|
||||||
|
for x, y in train_loader:
|
||||||
|
d, labels = model(x)
|
||||||
|
loss = criterion(d, y, labels).mean(0)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
y_pred = model.predict(x)
|
||||||
|
correct += (y_pred == y).float().sum(0)
|
||||||
|
|
||||||
|
acc = 100 * correct / len(train_ds)
|
||||||
|
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
Loading…
Reference in New Issue
Block a user