Compare commits
8 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
391473adf3 | ||
|
0d8db31ff2 | ||
|
89b96f0a98 | ||
|
ee4cf583e3 | ||
|
6ed1b9a832 | ||
|
4a7d4a3d99 | ||
|
0626af207f | ||
|
7b23983887 |
@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.7.4
|
current_version = 0.7.6
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
26
.github/workflows/pythonapp.yml
vendored
26
.github/workflows/pythonapp.yml
vendored
@ -12,36 +12,36 @@ jobs:
|
|||||||
style:
|
style:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.11
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.11"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .[all]
|
pip install .[all]
|
||||||
- uses: pre-commit/action@v2.0.3
|
- uses: pre-commit/action@v3.0.0
|
||||||
compatibility:
|
compatibility:
|
||||||
needs: style
|
needs: style
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||||
os: [ubuntu-latest, windows-latest]
|
os: [ubuntu-latest, windows-latest]
|
||||||
exclude:
|
exclude:
|
||||||
- os: windows-latest
|
|
||||||
python-version: "3.7"
|
|
||||||
- os: windows-latest
|
- os: windows-latest
|
||||||
python-version: "3.8"
|
python-version: "3.8"
|
||||||
- os: windows-latest
|
- os: windows-latest
|
||||||
python-version: "3.9"
|
python-version: "3.9"
|
||||||
|
- os: windows-latest
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
@ -56,11 +56,11 @@ jobs:
|
|||||||
needs: compatibility
|
needs: compatibility
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- name: Set up Python 3.10
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: "3.10"
|
python-version: "3.11"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.1.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@ -13,17 +13,17 @@ repos:
|
|||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
|
|
||||||
- repo: https://github.com/myint/autoflake
|
- repo: https://github.com/myint/autoflake
|
||||||
rev: v1.4
|
rev: v2.1.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: autoflake
|
- id: autoflake
|
||||||
|
|
||||||
- repo: http://github.com/PyCQA/isort
|
- repo: http://github.com/PyCQA/isort
|
||||||
rev: 5.10.1
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v0.931
|
rev: v1.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
files: prototorch
|
files: prototorch
|
||||||
@ -35,14 +35,14 @@ repos:
|
|||||||
- id: yapf
|
- id: yapf
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.9.0
|
rev: v1.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-use-type-annotations
|
- id: python-use-type-annotations
|
||||||
- id: python-no-log-warn
|
- id: python-no-log-warn
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v2.31.0
|
rev: v3.7.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.7.4"
|
release = "0.7.6"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
76
examples/gmlvq.py
Normal file
76
examples/gmlvq.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
"""ProtoTorch GMLVQ example using Iris data."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
|
||||||
|
class GMLVQ(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of Generalized Matrix Learning Vector Quantization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.components_layer = pt.components.LabeledComponents(
|
||||||
|
distribution=[1, 1, 1],
|
||||||
|
components_initializer=pt.initializers.SMCI(data, noise=0.1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.backbone = pt.transforms.Omega(
|
||||||
|
len(data[0][0]),
|
||||||
|
len(data[0][0]),
|
||||||
|
pt.initializers.RandomLinearTransformInitializer(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
"""
|
||||||
|
Forward function that returns a tuple of dissimilarities and label information.
|
||||||
|
Feed into GLVQLoss to get a complete GMLVQ model.
|
||||||
|
"""
|
||||||
|
components, label = self.components_layer()
|
||||||
|
|
||||||
|
latent_x = self.backbone(data)
|
||||||
|
latent_components = self.backbone(components)
|
||||||
|
|
||||||
|
distance = pt.distances.squared_euclidean_distance(
|
||||||
|
latent_x, latent_components)
|
||||||
|
|
||||||
|
return distance, label
|
||||||
|
|
||||||
|
def predict(self, data):
|
||||||
|
"""
|
||||||
|
The GMLVQ has a modified prediction step, where a competition layer is applied.
|
||||||
|
"""
|
||||||
|
components, label = self.components_layer()
|
||||||
|
distance = pt.distances.squared_euclidean_distance(data, components)
|
||||||
|
winning_label = pt.competitions.wtac(distance, label)
|
||||||
|
return winning_label
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_ds = pt.datasets.Iris()
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||||
|
|
||||||
|
model = GMLVQ(train_ds)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
|
||||||
|
criterion = pt.losses.GLVQLoss()
|
||||||
|
|
||||||
|
for epoch in range(200):
|
||||||
|
correct = 0.0
|
||||||
|
for x, y in train_loader:
|
||||||
|
d, labels = model(x)
|
||||||
|
loss = criterion(d, y, labels).mean(0)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
y_pred = model.predict(x)
|
||||||
|
correct += (y_pred == y).float().sum(0)
|
||||||
|
|
||||||
|
acc = 100 * correct / len(train_ds)
|
||||||
|
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
@ -17,7 +17,7 @@ from .core import similarities # noqa: F401
|
|||||||
from .core import transforms # noqa: F401
|
from .core import transforms # noqa: F401
|
||||||
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
__version__ = "0.7.4"
|
__version__ = "0.7.6"
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
"competitions",
|
"competitions",
|
||||||
|
@ -11,7 +11,7 @@ def squared_euclidean_distance(x, y):
|
|||||||
**Alias:**
|
**Alias:**
|
||||||
``prototorch.functions.distances.sed``
|
``prototorch.functions.distances.sed``
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
expanded_x = x.unsqueeze(dim=1)
|
expanded_x = x.unsqueeze(dim=1)
|
||||||
batchwise_difference = y - expanded_x
|
batchwise_difference = y - expanded_x
|
||||||
differences_raised = torch.pow(batchwise_difference, 2)
|
differences_raised = torch.pow(batchwise_difference, 2)
|
||||||
@ -27,14 +27,14 @@ def euclidean_distance(x, y):
|
|||||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||||
:rtype: `torch.tensor`
|
:rtype: `torch.tensor`
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
distances_raised = squared_euclidean_distance(x, y)
|
distances_raised = squared_euclidean_distance(x, y)
|
||||||
distances = torch.sqrt(distances_raised)
|
distances = torch.sqrt(distances_raised)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
def euclidean_distance_v2(x, y):
|
def euclidean_distance_v2(x, y):
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
diff = y - x.unsqueeze(1)
|
diff = y - x.unsqueeze(1)
|
||||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||||
@ -54,7 +54,7 @@ def lpnorm_distance(x, y, p):
|
|||||||
|
|
||||||
:param p: p parameter of the lp norm
|
:param p: p parameter of the lp norm
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
distances = torch.cdist(x, y, p=p)
|
distances = torch.cdist(x, y, p=p)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@ -66,7 +66,7 @@ def omega_distance(x, y, omega):
|
|||||||
|
|
||||||
:param `torch.tensor` omega: Two dimensional matrix
|
:param `torch.tensor` omega: Two dimensional matrix
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
projected_x = x @ omega
|
projected_x = x @ omega
|
||||||
projected_y = y @ omega
|
projected_y = y @ omega
|
||||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||||
@ -80,7 +80,7 @@ def lomega_distance(x, y, omegas):
|
|||||||
|
|
||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
projected_x = x @ omegas
|
projected_x = x @ omegas
|
||||||
projected_y = torch.diagonal(y @ omegas).T
|
projected_y = torch.diagonal(y @ omegas).T
|
||||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||||
|
@ -21,7 +21,7 @@ def cosine_similarity(x, y):
|
|||||||
Expected dimension of x is 2.
|
Expected dimension of x is 2.
|
||||||
Expected dimension of y is 2.
|
Expected dimension of y is 2.
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
norm_x = x.pow(2).sum(1).sqrt()
|
norm_x = x.pow(2).sum(1).sqrt()
|
||||||
norm_y = y.pow(2).sum(1).sqrt()
|
norm_y = y.pow(2).sum(1).sqrt()
|
||||||
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
||||||
|
@ -20,7 +20,7 @@ class Dataset(torch.utils.data.Dataset):
|
|||||||
_repr_indent = 2
|
_repr_indent = 2
|
||||||
|
|
||||||
def __init__(self, root):
|
def __init__(self, root):
|
||||||
if isinstance(root, torch._six.string_classes):
|
if isinstance(root, str):
|
||||||
root = os.path.expanduser(root)
|
root = os.path.expanduser(root)
|
||||||
self.root = root
|
self.root = root
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from typing import (
|
|||||||
Dict,
|
Dict,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Optional,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -18,7 +19,7 @@ def generate_mesh(
|
|||||||
maxima: torch.TensorType,
|
maxima: torch.TensorType,
|
||||||
border: float = 1.0,
|
border: float = 1.0,
|
||||||
resolution: int = 100,
|
resolution: int = 100,
|
||||||
device: torch.device = None,
|
device: Optional[torch.device] = None,
|
||||||
):
|
):
|
||||||
# Apply Border
|
# Apply Border
|
||||||
ptp = maxima - minima
|
ptp = maxima - minima
|
||||||
@ -55,14 +56,15 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
|||||||
|
|
||||||
|
|
||||||
def distribution_from_list(list_dist: List[int],
|
def distribution_from_list(list_dist: List[int],
|
||||||
clabels: Iterable[int] = None):
|
clabels: Optional[Iterable[int]] = None):
|
||||||
clabels = clabels or list(range(len(list_dist)))
|
clabels = clabels or list(range(len(list_dist)))
|
||||||
distribution = dict(zip(clabels, list_dist))
|
distribution = dict(zip(clabels, list_dist))
|
||||||
return distribution
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
def parse_distribution(user_distribution,
|
def parse_distribution(
|
||||||
clabels: Iterable[int] = None) -> Dict[int, int]:
|
user_distribution,
|
||||||
|
clabels: Optional[Iterable[int]] = None) -> Dict[int, int]:
|
||||||
"""Parse user-provided distribution.
|
"""Parse user-provided distribution.
|
||||||
|
|
||||||
Return a dictionary with integer keys that represent the class labels and
|
Return a dictionary with integer keys that represent the class labels and
|
||||||
|
16
setup.py
16
setup.py
@ -15,14 +15,14 @@ from setuptools import find_packages, setup
|
|||||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
||||||
|
|
||||||
with open("README.md", "r") as fh:
|
with open("README.md", encoding="utf-8") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"torch>=1.3.1",
|
"torch>=2.0.0",
|
||||||
"torchvision>=0.7.4",
|
"torchvision",
|
||||||
"numpy>=1.9.1",
|
"numpy",
|
||||||
"sklearn",
|
"scikit-learn",
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
]
|
]
|
||||||
DATASETS = [
|
DATASETS = [
|
||||||
@ -51,7 +51,7 @@ ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="prototorch",
|
name="prototorch",
|
||||||
version="0.7.4",
|
version="0.7.6",
|
||||||
description="Highly extensible, GPU-supported "
|
description="Highly extensible, GPU-supported "
|
||||||
"Learning Vector Quantization (LVQ) toolbox "
|
"Learning Vector Quantization (LVQ) toolbox "
|
||||||
"built using PyTorch and its nn API.",
|
"built using PyTorch and its nn API.",
|
||||||
@ -62,7 +62,7 @@ setup(
|
|||||||
url=PROJECT_URL,
|
url=PROJECT_URL,
|
||||||
download_url=DOWNLOAD_URL,
|
download_url=DOWNLOAD_URL,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
python_requires=">=3.7",
|
python_requires=">=3.8",
|
||||||
install_requires=INSTALL_REQUIRES,
|
install_requires=INSTALL_REQUIRES,
|
||||||
extras_require={
|
extras_require={
|
||||||
"datasets": DATASETS,
|
"datasets": DATASETS,
|
||||||
@ -85,10 +85,10 @@ setup(
|
|||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"Programming Language :: Python :: 3.7",
|
|
||||||
"Programming Language :: Python :: 3.8",
|
"Programming Language :: Python :: 3.8",
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
"Programming Language :: Python :: 3.10",
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
],
|
],
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""ProtoTorch datasets test suite"""
|
"""ProtoTorch datasets test suite"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
Loading…
Reference in New Issue
Block a user