Compare commits
46 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
391473adf3 | ||
|
0d8db31ff2 | ||
|
89b96f0a98 | ||
|
ee4cf583e3 | ||
|
6ed1b9a832 | ||
|
4a7d4a3d99 | ||
|
0626af207f | ||
|
7b23983887 | ||
|
0649d5bb45 | ||
|
339316aa7e | ||
|
2a85c94b55 | ||
|
6714cb7915 | ||
|
a501ab6c3b | ||
|
37add944b1 | ||
|
0d10fc7e25 | ||
|
71a2e74eff | ||
|
85f75bb28c | ||
|
46ff1c4eb1 | ||
|
ed5b9b6c62 | ||
|
08b3f9bbb9 | ||
|
784a963527 | ||
|
236cbbc4d2 | ||
|
695559fd4a | ||
|
a54acdef22 | ||
|
bebd13868f | ||
|
62df3c0457 | ||
|
cce76c7940 | ||
|
ca24422ab0 | ||
|
a28601751e | ||
|
07a2d6caaa | ||
|
3d3d27fbab | ||
|
b49b7a2d41 | ||
|
b6e8242383 | ||
|
cd616d11b9 | ||
|
afcfcb8973 | ||
|
bf03a45475 | ||
|
083b5c1597 | ||
|
7f0a8e9bce | ||
|
bf09ff8f7f | ||
|
c1d7cfee8f | ||
|
99be965581 | ||
|
fdb9a7c66d | ||
|
eb79b703d8 | ||
|
bc9a826b7d | ||
|
cfe09ec06b | ||
|
3d76dffe3c |
@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.6.0
|
current_version = 0.7.6
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
15
.codacy.yml
15
.codacy.yml
@ -1,15 +0,0 @@
|
|||||||
# To validate the contents of your configuration file
|
|
||||||
# run the following command in the folder where the configuration file is located:
|
|
||||||
# codacy-analysis-cli validate-configuration --directory `pwd`
|
|
||||||
# To analyse, run:
|
|
||||||
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
|
|
||||||
---
|
|
||||||
engines:
|
|
||||||
pylintpython3:
|
|
||||||
exclude_paths:
|
|
||||||
- config/engines.yml
|
|
||||||
remark-lint:
|
|
||||||
exclude_paths:
|
|
||||||
- config/engines.yml
|
|
||||||
exclude_paths:
|
|
||||||
- 'tests/**'
|
|
@ -1,2 +0,0 @@
|
|||||||
comment:
|
|
||||||
require_changes: yes
|
|
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@ -10,21 +10,28 @@ assignees: ''
|
|||||||
**Describe the bug**
|
**Describe the bug**
|
||||||
A clear and concise description of what the bug is.
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
**To Reproduce**
|
**Steps to reproduce the behavior**
|
||||||
Steps to reproduce the behavior:
|
1. ...
|
||||||
1. Install Prototorch by running '...'
|
2. Run script '...' or this snippet:
|
||||||
2. Run script '...'
|
```python
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
...
|
||||||
|
```
|
||||||
3. See errors
|
3. See errors
|
||||||
|
|
||||||
**Expected behavior**
|
**Expected behavior**
|
||||||
A clear and concise description of what you expected to happen.
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Observed behavior**
|
||||||
|
A clear and concise description of what actually happened.
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
|
|
||||||
**Desktop (please complete the following information):**
|
**System and version information**
|
||||||
- OS: [e.g. Ubuntu 20.10]
|
- OS: [e.g. Ubuntu 20.10]
|
||||||
- Prototorch Version: [e.g. v0.4.0]
|
- ProtoTorch Version: [e.g. 0.4.0]
|
||||||
- Python Version: [e.g. 3.9.5]
|
- Python Version: [e.g. 3.9.5]
|
||||||
|
|
||||||
**Additional context**
|
**Additional context**
|
||||||
|
88
.github/workflows/pythonapp.yml
vendored
88
.github/workflows/pythonapp.yml
vendored
@ -5,33 +5,71 @@ name: tests
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ master, dev ]
|
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ master ]
|
branches: [master]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
style:
|
||||||
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.9
|
- name: Set up Python 3.11
|
||||||
uses: actions/setup-python@v1
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: 3.9
|
python-version: "3.11"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install .[all]
|
pip install .[all]
|
||||||
- name: Lint with flake8
|
- uses: pre-commit/action@v3.0.0
|
||||||
run: |
|
compatibility:
|
||||||
pip install flake8
|
needs: style
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
strategy:
|
||||||
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
|
fail-fast: false
|
||||||
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
|
matrix:
|
||||||
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
|
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||||
- name: Test with pytest
|
os: [ubuntu-latest, windows-latest]
|
||||||
run: |
|
exclude:
|
||||||
pip install pytest
|
- os: windows-latest
|
||||||
pytest
|
python-version: "3.8"
|
||||||
|
- os: windows-latest
|
||||||
|
python-version: "3.9"
|
||||||
|
- os: windows-latest
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install .[all]
|
||||||
|
- name: Test with pytest
|
||||||
|
run: |
|
||||||
|
pytest
|
||||||
|
publish_pypi:
|
||||||
|
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
|
||||||
|
needs: compatibility
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v3
|
||||||
|
- name: Set up Python 3.10
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install .[all]
|
||||||
|
pip install wheel
|
||||||
|
- name: Build package
|
||||||
|
run: python setup.py sdist bdist_wheel
|
||||||
|
- name: Publish a Python distribution to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
|
with:
|
||||||
|
user: __token__
|
||||||
|
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.0.1
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
@ -13,36 +13,36 @@ repos:
|
|||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
|
|
||||||
- repo: https://github.com/myint/autoflake
|
- repo: https://github.com/myint/autoflake
|
||||||
rev: v1.4
|
rev: v2.1.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: autoflake
|
- id: autoflake
|
||||||
|
|
||||||
- repo: http://github.com/PyCQA/isort
|
- repo: http://github.com/PyCQA/isort
|
||||||
rev: 5.8.0
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v0.902
|
rev: v1.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
files: prototorch
|
files: prototorch
|
||||||
additional_dependencies: [types-pkg_resources]
|
additional_dependencies: [types-pkg_resources]
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
rev: v0.31.0
|
rev: v0.32.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: yapf
|
- id: yapf
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.9.0
|
rev: v1.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-use-type-annotations
|
- id: python-use-type-annotations
|
||||||
- id: python-no-log-warn
|
- id: python-no-log-warn
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
- repo: https://github.com/asottile/pyupgrade
|
- repo: https://github.com/asottile/pyupgrade
|
||||||
rev: v2.19.4
|
rev: v3.7.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: pyupgrade
|
- id: pyupgrade
|
||||||
|
|
||||||
|
36
.travis.yml
36
.travis.yml
@ -1,36 +0,0 @@
|
|||||||
dist: bionic
|
|
||||||
sudo: false
|
|
||||||
language: python
|
|
||||||
python: 3.9
|
|
||||||
cache:
|
|
||||||
directories:
|
|
||||||
- "$HOME/.cache/pip"
|
|
||||||
- "./tests/artifacts"
|
|
||||||
- "$HOME/datasets"
|
|
||||||
install:
|
|
||||||
- pip install .[all] --progress-bar off
|
|
||||||
|
|
||||||
# Generate code coverage report
|
|
||||||
script:
|
|
||||||
- coverage run -m pytest
|
|
||||||
|
|
||||||
# Push the results to codecov
|
|
||||||
after_success:
|
|
||||||
- bash <(curl -s https://codecov.io/bash)
|
|
||||||
|
|
||||||
# Publish on PyPI
|
|
||||||
deploy:
|
|
||||||
provider: pypi
|
|
||||||
username: __token__
|
|
||||||
password:
|
|
||||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
|
||||||
on:
|
|
||||||
tags: true
|
|
||||||
skip_existing: true
|
|
||||||
|
|
||||||
# The password is encrypted with:
|
|
||||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
|
||||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
|
||||||
# https://github.com/travis-ci/travis.rb#installation
|
|
||||||
# for more details
|
|
||||||
# Note: The encrypt command does not work well in ZSH.
|
|
3
LICENSE
3
LICENSE
@ -1,6 +1,7 @@
|
|||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2020 si-cim
|
Copyright (c) 2020 Saxon Institute for Computational Intelligence and Machine
|
||||||
|
Learning (SICIM)
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
@ -2,12 +2,9 @@
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
[](https://travis-ci.com/github/si-cim/prototorch)
|
|
||||||

|

|
||||||
[](https://github.com/si-cim/prototorch/releases)
|
[](https://github.com/si-cim/prototorch/releases)
|
||||||
[](https://pypi.org/project/prototorch/)
|
[](https://pypi.org/project/prototorch/)
|
||||||
[](https://codecov.io/gh/si-cim/prototorch)
|
|
||||||
[](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&utm_medium=referral&utm_content=si-cim/prototorch&utm_campaign=Badge_Grade)
|
|
||||||
[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
|
[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
|
||||||
|
|
||||||
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
|
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)
|
||||||
|
@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.6.0"
|
release = "0.7.6"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
@ -120,7 +120,7 @@ html_css_files = [
|
|||||||
# -- Options for HTMLHelp output ------------------------------------------
|
# -- Options for HTMLHelp output ------------------------------------------
|
||||||
|
|
||||||
# Output file base name for HTML help builder.
|
# Output file base name for HTML help builder.
|
||||||
htmlhelp_basename = "protoflowdoc"
|
htmlhelp_basename = "prototorchdoc"
|
||||||
|
|
||||||
# -- Options for LaTeX output ---------------------------------------------
|
# -- Options for LaTeX output ---------------------------------------------
|
||||||
|
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""ProtoTorch CBC example using 2D Iris data."""
|
"""ProtoTorch CBC example using 2D Iris data."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
@ -7,6 +9,7 @@ import prototorch as pt
|
|||||||
|
|
||||||
|
|
||||||
class CBC(torch.nn.Module):
|
class CBC(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, data, **kwargs):
|
def __init__(self, data, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.components_layer = pt.components.ReasoningComponents(
|
self.components_layer = pt.components.ReasoningComponents(
|
||||||
@ -23,6 +26,7 @@ class CBC(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class VisCBC2D():
|
class VisCBC2D():
|
||||||
|
|
||||||
def __init__(self, model, data):
|
def __init__(self, model, data):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.x_train, self.y_train = pt.utils.parse_data_arg(data)
|
self.x_train, self.y_train = pt.utils.parse_data_arg(data)
|
||||||
@ -32,7 +36,7 @@ class VisCBC2D():
|
|||||||
self.resolution = 100
|
self.resolution = 100
|
||||||
self.cmap = "viridis"
|
self.cmap = "viridis"
|
||||||
|
|
||||||
def on_epoch_end(self):
|
def on_train_epoch_end(self):
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
_components = self.model.components_layer._components.detach()
|
_components = self.model.components_layer._components.detach()
|
||||||
ax = self.fig.gca()
|
ax = self.fig.gca()
|
||||||
@ -92,5 +96,5 @@ if __name__ == "__main__":
|
|||||||
correct += (y_pred.argmax(1) == y).float().sum(0)
|
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||||
|
|
||||||
acc = 100 * correct / len(train_ds)
|
acc = 100 * correct / len(train_ds)
|
||||||
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||||
vis.on_epoch_end()
|
vis.on_train_epoch_end()
|
||||||
|
76
examples/gmlvq.py
Normal file
76
examples/gmlvq.py
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
"""ProtoTorch GMLVQ example using Iris data."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
|
||||||
|
class GMLVQ(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of Generalized Matrix Learning Vector Quantization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.components_layer = pt.components.LabeledComponents(
|
||||||
|
distribution=[1, 1, 1],
|
||||||
|
components_initializer=pt.initializers.SMCI(data, noise=0.1),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.backbone = pt.transforms.Omega(
|
||||||
|
len(data[0][0]),
|
||||||
|
len(data[0][0]),
|
||||||
|
pt.initializers.RandomLinearTransformInitializer(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
"""
|
||||||
|
Forward function that returns a tuple of dissimilarities and label information.
|
||||||
|
Feed into GLVQLoss to get a complete GMLVQ model.
|
||||||
|
"""
|
||||||
|
components, label = self.components_layer()
|
||||||
|
|
||||||
|
latent_x = self.backbone(data)
|
||||||
|
latent_components = self.backbone(components)
|
||||||
|
|
||||||
|
distance = pt.distances.squared_euclidean_distance(
|
||||||
|
latent_x, latent_components)
|
||||||
|
|
||||||
|
return distance, label
|
||||||
|
|
||||||
|
def predict(self, data):
|
||||||
|
"""
|
||||||
|
The GMLVQ has a modified prediction step, where a competition layer is applied.
|
||||||
|
"""
|
||||||
|
components, label = self.components_layer()
|
||||||
|
distance = pt.distances.squared_euclidean_distance(data, components)
|
||||||
|
winning_label = pt.competitions.wtac(distance, label)
|
||||||
|
return winning_label
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_ds = pt.datasets.Iris()
|
||||||
|
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
|
||||||
|
|
||||||
|
model = GMLVQ(train_ds)
|
||||||
|
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
|
||||||
|
criterion = pt.losses.GLVQLoss()
|
||||||
|
|
||||||
|
for epoch in range(200):
|
||||||
|
correct = 0.0
|
||||||
|
for x, y in train_loader:
|
||||||
|
d, labels = model(x)
|
||||||
|
loss = criterion(d, y, labels).mean(0)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
y_pred = model.predict(x)
|
||||||
|
correct += (y_pred == y).float().sum(0)
|
||||||
|
|
||||||
|
acc = 100 * correct / len(train_ds)
|
||||||
|
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
@ -1,28 +1,23 @@
|
|||||||
"""ProtoTorch package"""
|
"""ProtoTorch package"""
|
||||||
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
from . import (
|
from . import datasets # noqa: F401
|
||||||
datasets,
|
from . import nn # noqa: F401
|
||||||
nn,
|
from . import utils # noqa: F401
|
||||||
utils,
|
from .core import competitions # noqa: F401
|
||||||
)
|
from .core import components # noqa: F401
|
||||||
from .core import (
|
from .core import distances # noqa: F401
|
||||||
competitions,
|
from .core import initializers # noqa: F401
|
||||||
components,
|
from .core import losses # noqa: F401
|
||||||
distances,
|
from .core import pooling # noqa: F401
|
||||||
initializers,
|
from .core import similarities # noqa: F401
|
||||||
losses,
|
from .core import transforms # noqa: F401
|
||||||
pooling,
|
|
||||||
similarities,
|
|
||||||
transforms,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
__version__ = "0.6.0"
|
__version__ = "0.7.6"
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
"competitions",
|
"competitions",
|
||||||
@ -40,7 +35,7 @@ __all_core__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Plugin Loader
|
# Plugin Loader
|
||||||
__path__: List[str] = pkgutil.extend_path(__path__, __name__)
|
__path__ = pkgutil.extend_path(__path__, __name__)
|
||||||
|
|
||||||
|
|
||||||
def discover_plugins():
|
def discover_plugins():
|
||||||
|
@ -38,7 +38,7 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
|
|||||||
pk = A
|
pk = A
|
||||||
nk = (1 - A) * B
|
nk = (1 - A) * B
|
||||||
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
||||||
probs = numerator / (pk + nk).sum(1)
|
probs = numerator / ((pk + nk).sum(1) + 1e-8)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +48,8 @@ class WTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, distances, labels):
|
|
||||||
|
def forward(self, distances, labels): # pylint: disable=no-self-use
|
||||||
return wtac(distances, labels)
|
return wtac(distances, labels)
|
||||||
|
|
||||||
|
|
||||||
@ -58,7 +59,8 @@ class LTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, probs, labels):
|
|
||||||
|
def forward(self, probs, labels): # pylint: disable=no-self-use
|
||||||
return wtac(-1.0 * probs, labels)
|
return wtac(-1.0 * probs, labels)
|
||||||
|
|
||||||
|
|
||||||
@ -68,6 +70,7 @@ class KNNC(torch.nn.Module):
|
|||||||
Thin wrapper over the `knnc` function.
|
Thin wrapper over the `knnc` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, k=1, **kwargs):
|
def __init__(self, k=1, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.k = k
|
self.k = k
|
||||||
@ -85,5 +88,6 @@ class CBCC(torch.nn.Module):
|
|||||||
Thin wrapper over the `cbcc` function.
|
Thin wrapper over the `cbcc` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, detections, reasonings):
|
|
||||||
|
def forward(self, detections, reasonings): # pylint: disable=no-self-use
|
||||||
return cbcc(detections, reasonings)
|
return cbcc(detections, reasonings)
|
||||||
|
@ -6,7 +6,8 @@ from typing import Union
|
|||||||
import torch
|
import torch
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from ..utils import parse_distribution
|
from prototorch.utils import parse_distribution
|
||||||
|
|
||||||
from .initializers import (
|
from .initializers import (
|
||||||
AbstractClassAwareCompInitializer,
|
AbstractClassAwareCompInitializer,
|
||||||
AbstractComponentsInitializer,
|
AbstractComponentsInitializer,
|
||||||
@ -63,6 +64,7 @@ def get_cikwargs(init, distribution):
|
|||||||
|
|
||||||
class AbstractComponents(torch.nn.Module):
|
class AbstractComponents(torch.nn.Module):
|
||||||
"""Abstract class for all components modules."""
|
"""Abstract class for all components modules."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_components(self):
|
def num_components(self):
|
||||||
"""Current number of components."""
|
"""Current number of components."""
|
||||||
@ -85,6 +87,7 @@ class AbstractComponents(torch.nn.Module):
|
|||||||
|
|
||||||
class Components(AbstractComponents):
|
class Components(AbstractComponents):
|
||||||
"""A set of adaptable Tensors."""
|
"""A set of adaptable Tensors."""
|
||||||
|
|
||||||
def __init__(self, num_components: int,
|
def __init__(self, num_components: int,
|
||||||
initializer: AbstractComponentsInitializer):
|
initializer: AbstractComponentsInitializer):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -112,6 +115,7 @@ class Components(AbstractComponents):
|
|||||||
|
|
||||||
class AbstractLabels(torch.nn.Module):
|
class AbstractLabels(torch.nn.Module):
|
||||||
"""Abstract class for all labels modules."""
|
"""Abstract class for all labels modules."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
return self._labels.cpu()
|
return self._labels.cpu()
|
||||||
@ -152,6 +156,7 @@ class AbstractLabels(torch.nn.Module):
|
|||||||
|
|
||||||
class Labels(AbstractLabels):
|
class Labels(AbstractLabels):
|
||||||
"""A set of standalone labels."""
|
"""A set of standalone labels."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
initializer: AbstractLabelsInitializer = LabelsInitializer()):
|
||||||
@ -182,6 +187,7 @@ class Labels(AbstractLabels):
|
|||||||
|
|
||||||
class LabeledComponents(AbstractComponents):
|
class LabeledComponents(AbstractComponents):
|
||||||
"""A set of adaptable components and corresponding unadaptable labels."""
|
"""A set of adaptable components and corresponding unadaptable labels."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
@ -249,12 +255,15 @@ class Reasonings(torch.nn.Module):
|
|||||||
The `reasonings` tensor is of shape [num_components, num_classes, 2].
|
The `reasonings` tensor is of shape [num_components, num_classes, 2].
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
initializer:
|
initializer:
|
||||||
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
|
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.add_reasonings(distribution, initializer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
@ -306,6 +315,7 @@ class ReasoningComponents(AbstractComponents):
|
|||||||
three element probability distribution.
|
three element probability distribution.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
|
@ -11,7 +11,7 @@ def squared_euclidean_distance(x, y):
|
|||||||
**Alias:**
|
**Alias:**
|
||||||
``prototorch.functions.distances.sed``
|
``prototorch.functions.distances.sed``
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
expanded_x = x.unsqueeze(dim=1)
|
expanded_x = x.unsqueeze(dim=1)
|
||||||
batchwise_difference = y - expanded_x
|
batchwise_difference = y - expanded_x
|
||||||
differences_raised = torch.pow(batchwise_difference, 2)
|
differences_raised = torch.pow(batchwise_difference, 2)
|
||||||
@ -27,23 +27,20 @@ def euclidean_distance(x, y):
|
|||||||
:returns: Distance Tensor of shape :math:`X \times Y`
|
:returns: Distance Tensor of shape :math:`X \times Y`
|
||||||
:rtype: `torch.tensor`
|
:rtype: `torch.tensor`
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
distances_raised = squared_euclidean_distance(x, y)
|
distances_raised = squared_euclidean_distance(x, y)
|
||||||
distances = torch.sqrt(distances_raised)
|
distances = torch.sqrt(distances_raised)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
def euclidean_distance_v2(x, y):
|
def euclidean_distance_v2(x, y):
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
diff = y - x.unsqueeze(1)
|
diff = y - x.unsqueeze(1)
|
||||||
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
|
||||||
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
|
||||||
# batch diagonal. See:
|
# batch diagonal. See:
|
||||||
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||||
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||||
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
|
||||||
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
|
||||||
# print(f"{distances.shape=}") # (nx, ny)
|
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
@ -57,7 +54,7 @@ def lpnorm_distance(x, y, p):
|
|||||||
|
|
||||||
:param p: p parameter of the lp norm
|
:param p: p parameter of the lp norm
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
distances = torch.cdist(x, y, p=p)
|
distances = torch.cdist(x, y, p=p)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
@ -69,7 +66,7 @@ def omega_distance(x, y, omega):
|
|||||||
|
|
||||||
:param `torch.tensor` omega: Two dimensional matrix
|
:param `torch.tensor` omega: Two dimensional matrix
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
projected_x = x @ omega
|
projected_x = x @ omega
|
||||||
projected_y = y @ omega
|
projected_y = y @ omega
|
||||||
distances = squared_euclidean_distance(projected_x, projected_y)
|
distances = squared_euclidean_distance(projected_x, projected_y)
|
||||||
@ -83,7 +80,7 @@ def lomega_distance(x, y, omegas):
|
|||||||
|
|
||||||
:param `torch.tensor` omegas: Three dimensional matrix
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
projected_x = x @ omegas
|
projected_x = x @ omegas
|
||||||
projected_y = torch.diagonal(y @ omegas).T
|
projected_y = torch.diagonal(y @ omegas).T
|
||||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||||
|
@ -11,7 +11,7 @@ from typing import (
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..utils import parse_data_arg, parse_distribution
|
from prototorch.utils import parse_data_arg, parse_distribution
|
||||||
|
|
||||||
|
|
||||||
# Components
|
# Components
|
||||||
@ -26,11 +26,18 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
|
|||||||
Use this to 'generate' pre-initialized components elsewhere.
|
Use this to 'generate' pre-initialized components elsewhere.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, components):
|
def __init__(self, components):
|
||||||
self.components = components
|
self.components = components
|
||||||
|
|
||||||
def generate(self, num_components: int = 0):
|
def generate(self, num_components: int = 0):
|
||||||
"""Ignore `num_components` and simply return `self.components`."""
|
"""Ignore `num_components` and simply return `self.components`."""
|
||||||
|
provided_num_components = len(self.components)
|
||||||
|
if provided_num_components != num_components:
|
||||||
|
wmsg = f"The number of components ({provided_num_components}) " \
|
||||||
|
f"provided to {self.__class__.__name__} " \
|
||||||
|
f"does not match the expected number ({num_components})."
|
||||||
|
warnings.warn(wmsg)
|
||||||
if not isinstance(self.components, torch.Tensor):
|
if not isinstance(self.components, torch.Tensor):
|
||||||
wmsg = f"Converting components to {torch.Tensor}..."
|
wmsg = f"Converting components to {torch.Tensor}..."
|
||||||
warnings.warn(wmsg)
|
warnings.warn(wmsg)
|
||||||
@ -40,6 +47,7 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||||
"""Abstract class for all dimension-aware components initializers."""
|
"""Abstract class for all dimension-aware components initializers."""
|
||||||
|
|
||||||
def __init__(self, shape: Union[Iterable, int]):
|
def __init__(self, shape: Union[Iterable, int]):
|
||||||
if isinstance(shape, Iterable):
|
if isinstance(shape, Iterable):
|
||||||
self.component_shape = tuple(shape)
|
self.component_shape = tuple(shape)
|
||||||
@ -53,6 +61,7 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
||||||
"""Generate zeros corresponding to the components shape."""
|
"""Generate zeros corresponding to the components shape."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
components = torch.zeros((num_components, ) + self.component_shape)
|
components = torch.zeros((num_components, ) + self.component_shape)
|
||||||
return components
|
return components
|
||||||
@ -60,6 +69,7 @@ class ZerosCompInitializer(ShapeAwareCompInitializer):
|
|||||||
|
|
||||||
class OnesCompInitializer(ShapeAwareCompInitializer):
|
class OnesCompInitializer(ShapeAwareCompInitializer):
|
||||||
"""Generate ones corresponding to the components shape."""
|
"""Generate ones corresponding to the components shape."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
components = torch.ones((num_components, ) + self.component_shape)
|
components = torch.ones((num_components, ) + self.component_shape)
|
||||||
return components
|
return components
|
||||||
@ -67,6 +77,7 @@ class OnesCompInitializer(ShapeAwareCompInitializer):
|
|||||||
|
|
||||||
class FillValueCompInitializer(OnesCompInitializer):
|
class FillValueCompInitializer(OnesCompInitializer):
|
||||||
"""Generate components with the provided `fill_value`."""
|
"""Generate components with the provided `fill_value`."""
|
||||||
|
|
||||||
def __init__(self, shape, fill_value: float = 1.0):
|
def __init__(self, shape, fill_value: float = 1.0):
|
||||||
super().__init__(shape)
|
super().__init__(shape)
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
@ -79,6 +90,7 @@ class FillValueCompInitializer(OnesCompInitializer):
|
|||||||
|
|
||||||
class UniformCompInitializer(OnesCompInitializer):
|
class UniformCompInitializer(OnesCompInitializer):
|
||||||
"""Generate components by sampling from a continuous uniform distribution."""
|
"""Generate components by sampling from a continuous uniform distribution."""
|
||||||
|
|
||||||
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
|
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
|
||||||
super().__init__(shape)
|
super().__init__(shape)
|
||||||
self.minimum = minimum
|
self.minimum = minimum
|
||||||
@ -93,6 +105,7 @@ class UniformCompInitializer(OnesCompInitializer):
|
|||||||
|
|
||||||
class RandomNormalCompInitializer(OnesCompInitializer):
|
class RandomNormalCompInitializer(OnesCompInitializer):
|
||||||
"""Generate components by sampling from a standard normal distribution."""
|
"""Generate components by sampling from a standard normal distribution."""
|
||||||
|
|
||||||
def __init__(self, shape, shift=0.0, scale=1.0):
|
def __init__(self, shape, shift=0.0, scale=1.0):
|
||||||
super().__init__(shape)
|
super().__init__(shape)
|
||||||
self.shift = shift
|
self.shift = shift
|
||||||
@ -113,6 +126,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
`data` has to be a torch tensor.
|
`data` has to be a torch tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data: torch.Tensor,
|
data: torch.Tensor,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
@ -137,6 +151,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""'Generate' the components from the provided data."""
|
"""'Generate' the components from the provided data."""
|
||||||
|
|
||||||
def generate(self, num_components: int = 0):
|
def generate(self, num_components: int = 0):
|
||||||
"""Ignore `num_components` and simply return transformed `self.data`."""
|
"""Ignore `num_components` and simply return transformed `self.data`."""
|
||||||
components = self.generate_end_hook(self.data)
|
components = self.generate_end_hook(self.data)
|
||||||
@ -145,6 +160,7 @@ class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
|||||||
|
|
||||||
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""Generate components by uniformly sampling from the provided data."""
|
"""Generate components by uniformly sampling from the provided data."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
||||||
samples = self.data[indices]
|
samples = self.data[indices]
|
||||||
@ -154,6 +170,7 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
|||||||
|
|
||||||
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""Generate components by computing the mean of the provided data."""
|
"""Generate components by computing the mean of the provided data."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
mean = self.data.mean(dim=0)
|
mean = self.data.mean(dim=0)
|
||||||
repeat_dim = [num_components] + [1] * len(mean.shape)
|
repeat_dim = [num_components] + [1] * len(mean.shape)
|
||||||
@ -172,6 +189,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
target tensors.
|
target tensors.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data,
|
data,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
@ -199,6 +217,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
||||||
"""'Generate' components from provided data and requested distribution."""
|
"""'Generate' components from provided data and requested distribution."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
"""Ignore `distribution` and simply return transformed `self.data`."""
|
"""Ignore `distribution` and simply return transformed `self.data`."""
|
||||||
components = self.generate_end_hook(self.data)
|
components = self.generate_end_hook(self.data)
|
||||||
@ -207,6 +226,7 @@ class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
|||||||
|
|
||||||
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
||||||
"""Abstract class for all stratified components initializers."""
|
"""Abstract class for all stratified components initializers."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
|
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
|
||||||
@ -217,6 +237,8 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
|||||||
components = torch.tensor([])
|
components = torch.tensor([])
|
||||||
for k, v in distribution.items():
|
for k, v in distribution.items():
|
||||||
stratified_data = self.data[self.targets == k]
|
stratified_data = self.data[self.targets == k]
|
||||||
|
if len(stratified_data) == 0:
|
||||||
|
raise ValueError(f"No data available for class {k}.")
|
||||||
initializer = self.subinit_type(
|
initializer = self.subinit_type(
|
||||||
stratified_data,
|
stratified_data,
|
||||||
noise=self.noise,
|
noise=self.noise,
|
||||||
@ -229,6 +251,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
|||||||
|
|
||||||
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
||||||
"""Generate components using stratified sampling from the provided data."""
|
"""Generate components using stratified sampling from the provided data."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subinit_type(self):
|
def subinit_type(self):
|
||||||
return SelectionCompInitializer
|
return SelectionCompInitializer
|
||||||
@ -236,6 +259,7 @@ class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
|||||||
|
|
||||||
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
||||||
"""Generate components at stratified means of the provided data."""
|
"""Generate components at stratified means of the provided data."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subinit_type(self):
|
def subinit_type(self):
|
||||||
return MeanCompInitializer
|
return MeanCompInitializer
|
||||||
@ -244,6 +268,7 @@ class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
|||||||
# Labels
|
# Labels
|
||||||
class AbstractLabelsInitializer(ABC):
|
class AbstractLabelsInitializer(ABC):
|
||||||
"""Abstract class for all labels initializers."""
|
"""Abstract class for all labels initializers."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
...
|
...
|
||||||
@ -255,6 +280,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
|||||||
Use this to 'generate' pre-initialized labels elsewhere.
|
Use this to 'generate' pre-initialized labels elsewhere.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, labels):
|
def __init__(self, labels):
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
|
||||||
@ -273,6 +299,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
|||||||
|
|
||||||
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
||||||
"""'Generate' the labels from a torch Dataset."""
|
"""'Generate' the labels from a torch Dataset."""
|
||||||
|
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
self.data, self.targets = parse_data_arg(data)
|
self.data, self.targets = parse_data_arg(data)
|
||||||
|
|
||||||
@ -283,6 +310,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
|||||||
|
|
||||||
class LabelsInitializer(AbstractLabelsInitializer):
|
class LabelsInitializer(AbstractLabelsInitializer):
|
||||||
"""Generate labels from `distribution`."""
|
"""Generate labels from `distribution`."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
labels_list = []
|
labels_list = []
|
||||||
@ -294,6 +322,7 @@ class LabelsInitializer(AbstractLabelsInitializer):
|
|||||||
|
|
||||||
class OneHotLabelsInitializer(LabelsInitializer):
|
class OneHotLabelsInitializer(LabelsInitializer):
|
||||||
"""Generate one-hot-encoded labels from `distribution`."""
|
"""Generate one-hot-encoded labels from `distribution`."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
num_classes = len(distribution.keys())
|
num_classes = len(distribution.keys())
|
||||||
@ -303,17 +332,19 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
|||||||
|
|
||||||
|
|
||||||
# Reasonings
|
# Reasonings
|
||||||
|
def compute_distribution_shape(distribution):
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
|
num_components = sum(distribution.values())
|
||||||
|
num_classes = len(distribution.keys())
|
||||||
|
return (num_components, num_classes, 2)
|
||||||
|
|
||||||
|
|
||||||
class AbstractReasoningsInitializer(ABC):
|
class AbstractReasoningsInitializer(ABC):
|
||||||
"""Abstract class for all reasonings initializers."""
|
"""Abstract class for all reasonings initializers."""
|
||||||
|
|
||||||
def __init__(self, components_first: bool = True):
|
def __init__(self, components_first: bool = True):
|
||||||
self.components_first = components_first
|
self.components_first = components_first
|
||||||
|
|
||||||
def compute_shape(self, distribution):
|
|
||||||
distribution = parse_distribution(distribution)
|
|
||||||
num_components = sum(distribution.values())
|
|
||||||
num_classes = len(distribution.keys())
|
|
||||||
return (num_components, num_classes, 2)
|
|
||||||
|
|
||||||
def generate_end_hook(self, reasonings):
|
def generate_end_hook(self, reasonings):
|
||||||
if not self.components_first:
|
if not self.components_first:
|
||||||
reasonings = reasonings.permute(2, 1, 0)
|
reasonings = reasonings.permute(2, 1, 0)
|
||||||
@ -331,6 +362,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
Use this to 'generate' pre-initialized reasonings elsewhere.
|
Use this to 'generate' pre-initialized reasonings elsewhere.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reasonings, **kwargs):
|
def __init__(self, reasonings, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.reasonings = reasonings
|
self.reasonings = reasonings
|
||||||
@ -348,8 +380,9 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with zeros."""
|
"""Reasonings are all initialized with zeros."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.zeros(*shape)
|
reasonings = torch.zeros(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@ -357,8 +390,9 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with ones."""
|
"""Reasonings are all initialized with ones."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape)
|
reasonings = torch.ones(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@ -366,13 +400,14 @@ class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are randomly initialized."""
|
"""Reasonings are randomly initialized."""
|
||||||
|
|
||||||
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.minimum = minimum
|
self.minimum = minimum
|
||||||
self.maximum = maximum
|
self.maximum = maximum
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@ -380,8 +415,10 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Each component reasons positively for exactly one class."""
|
"""Each component reasons positively for exactly one class."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
num_components, num_classes, _ = compute_distribution_shape(
|
||||||
|
distribution)
|
||||||
A = OneHotLabelsInitializer().generate(distribution)
|
A = OneHotLabelsInitializer().generate(distribution)
|
||||||
B = torch.zeros(num_components, num_classes)
|
B = torch.zeros(num_components, num_classes)
|
||||||
reasonings = torch.stack([A, B], dim=-1)
|
reasonings = torch.stack([A, B], dim=-1)
|
||||||
@ -397,6 +434,7 @@ class AbstractTransformInitializer(ABC):
|
|||||||
|
|
||||||
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||||
"""Abstract class for all linear transform initializers."""
|
"""Abstract class for all linear transform initializers."""
|
||||||
|
|
||||||
def __init__(self, out_dim_first: bool = False):
|
def __init__(self, out_dim_first: bool = False):
|
||||||
self.out_dim_first = out_dim_first
|
self.out_dim_first = out_dim_first
|
||||||
|
|
||||||
@ -413,6 +451,7 @@ class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
|||||||
|
|
||||||
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Initialize a matrix with zeros."""
|
"""Initialize a matrix with zeros."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
weights = torch.zeros(in_dim, out_dim)
|
weights = torch.zeros(in_dim, out_dim)
|
||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
@ -420,13 +459,23 @@ class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Initialize a matrix with ones."""
|
"""Initialize a matrix with ones."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
weights = torch.ones(in_dim, out_dim)
|
weights = torch.ones(in_dim, out_dim)
|
||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
class RandomLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
|
"""Initialize a matrix with random values."""
|
||||||
|
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
weights = torch.rand(in_dim, out_dim)
|
||||||
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class EyeLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Initialize a matrix with the largest possible identity matrix."""
|
"""Initialize a matrix with the largest possible identity matrix."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
weights = torch.zeros(in_dim, out_dim)
|
weights = torch.zeros(in_dim, out_dim)
|
||||||
I = torch.eye(min(in_dim, out_dim))
|
I = torch.eye(min(in_dim, out_dim))
|
||||||
@ -436,6 +485,7 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Abstract class for all data-aware linear transform initializers."""
|
"""Abstract class for all data-aware linear transform initializers."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data: torch.Tensor,
|
data: torch.Tensor,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
@ -456,11 +506,19 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||||
"""Initialize a matrix with Eigenvectors from the data."""
|
"""Initialize a matrix with Eigenvectors from the data."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
_, _, weights = torch.pca_lowrank(self.data, q=out_dim)
|
||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer):
|
||||||
|
"""'Generate' the provided weights."""
|
||||||
|
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
return self.generate_end_hook(self.data)
|
||||||
|
|
||||||
|
|
||||||
# Aliases - Components
|
# Aliases - Components
|
||||||
CACI = ClassAwareCompInitializer
|
CACI = ClassAwareCompInitializer
|
||||||
DACI = DataAwareCompInitializer
|
DACI = DataAwareCompInitializer
|
||||||
@ -489,7 +547,9 @@ RRI = RandomReasoningsInitializer
|
|||||||
ZRI = ZerosReasoningsInitializer
|
ZRI = ZerosReasoningsInitializer
|
||||||
|
|
||||||
# Aliases - Transforms
|
# Aliases - Transforms
|
||||||
Eye = EyeTransformInitializer
|
ELTI = Eye = EyeLinearTransformInitializer
|
||||||
OLTI = OnesLinearTransformInitializer
|
OLTI = OnesLinearTransformInitializer
|
||||||
|
RLTI = RandomLinearTransformInitializer
|
||||||
ZLTI = ZerosLinearTransformInitializer
|
ZLTI = ZerosLinearTransformInitializer
|
||||||
PCALTI = PCALinearTransformInitializer
|
PCALTI = PCALinearTransformInitializer
|
||||||
|
LLTI = LiteralLinearTransformInitializer
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..nn.activations import get_activation
|
from prototorch.nn.activations import get_activation
|
||||||
|
|
||||||
|
|
||||||
# Helpers
|
# Helpers
|
||||||
@ -106,20 +106,31 @@ def margin_loss(y_pred, y_true, margin=0.3):
|
|||||||
|
|
||||||
|
|
||||||
class GLVQLoss(torch.nn.Module):
|
class GLVQLoss(torch.nn.Module):
|
||||||
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
|
||||||
|
def __init__(self,
|
||||||
|
margin=0.0,
|
||||||
|
transfer_fn="identity",
|
||||||
|
beta=10,
|
||||||
|
add_dp=False,
|
||||||
|
**kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.squashing = get_activation(squashing)
|
self.transfer_fn = get_activation(transfer_fn)
|
||||||
self.beta = torch.tensor(beta)
|
self.beta = torch.tensor(beta)
|
||||||
|
self.add_dp = add_dp
|
||||||
|
|
||||||
def forward(self, outputs, targets):
|
def forward(self, outputs, targets, plabels):
|
||||||
distances, plabels = outputs
|
# mu = glvq_loss(outputs, targets, plabels)
|
||||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
dp, dm = _get_dp_dm(outputs, targets, plabels)
|
||||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
mu = (dp - dm) / (dp + dm)
|
||||||
return torch.sum(batch_loss, dim=0)
|
if self.add_dp:
|
||||||
|
mu = mu + dp
|
||||||
|
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
|
||||||
|
return batch_loss.sum()
|
||||||
|
|
||||||
|
|
||||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
margin=0.3,
|
margin=0.3,
|
||||||
size_average=None,
|
size_average=None,
|
||||||
@ -133,6 +144,7 @@ class MarginLoss(torch.nn.modules.loss._Loss):
|
|||||||
|
|
||||||
|
|
||||||
class NeuralGasEnergy(torch.nn.Module):
|
class NeuralGasEnergy(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, lm, **kwargs):
|
def __init__(self, lm, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.lm = lm
|
self.lm = lm
|
||||||
@ -153,6 +165,7 @@ class NeuralGasEnergy(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
||||||
|
|
||||||
def __init__(self, topology_layer, **kwargs):
|
def __init__(self, topology_layer, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.topology_layer = topology_layer
|
self.topology_layer = topology_layer
|
||||||
|
@ -82,23 +82,27 @@ def stratified_prod_pooling(values: torch.Tensor,
|
|||||||
|
|
||||||
class StratifiedSumPooling(torch.nn.Module):
|
class StratifiedSumPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||||
def forward(self, values, labels):
|
|
||||||
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_sum_pooling(values, labels)
|
return stratified_sum_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedProdPooling(torch.nn.Module):
|
class StratifiedProdPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||||
def forward(self, values, labels):
|
|
||||||
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_prod_pooling(values, labels)
|
return stratified_prod_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMinPooling(torch.nn.Module):
|
class StratifiedMinPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||||
def forward(self, values, labels):
|
|
||||||
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_min_pooling(values, labels)
|
return stratified_min_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMaxPooling(torch.nn.Module):
|
class StratifiedMaxPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||||
def forward(self, values, labels):
|
|
||||||
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_max_pooling(values, labels)
|
return stratified_max_pooling(values, labels)
|
||||||
|
@ -21,7 +21,7 @@ def cosine_similarity(x, y):
|
|||||||
Expected dimension of x is 2.
|
Expected dimension of x is 2.
|
||||||
Expected dimension of y is 2.
|
Expected dimension of y is 2.
|
||||||
"""
|
"""
|
||||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
|
||||||
norm_x = x.pow(2).sum(1).sqrt()
|
norm_x = x.pow(2).sum(1).sqrt()
|
||||||
norm_y = y.pow(2).sum(1).sqrt()
|
norm_y = y.pow(2).sum(1).sqrt()
|
||||||
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T
|
||||||
|
@ -5,17 +5,18 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from .initializers import (
|
from .initializers import (
|
||||||
AbstractLinearTransformInitializer,
|
AbstractLinearTransformInitializer,
|
||||||
EyeTransformInitializer,
|
EyeLinearTransformInitializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LinearTransform(torch.nn.Module):
|
class LinearTransform(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_dim: int,
|
in_dim: int,
|
||||||
out_dim: int,
|
out_dim: int,
|
||||||
initializer:
|
initializer:
|
||||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.set_weights(in_dim, out_dim, initializer)
|
self.set_weights(in_dim, out_dim, initializer)
|
||||||
|
|
||||||
@ -31,12 +32,15 @@ class LinearTransform(torch.nn.Module):
|
|||||||
in_dim: int,
|
in_dim: int,
|
||||||
out_dim: int,
|
out_dim: int,
|
||||||
initializer:
|
initializer:
|
||||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
|
||||||
weights = initializer.generate(in_dim, out_dim)
|
weights = initializer.generate(in_dim, out_dim)
|
||||||
self._register_weights(weights)
|
self._register_weights(weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x @ self.weights.T
|
return x @ self._weights
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"weights: (shape: {tuple(self._weights.shape)})"
|
||||||
|
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""ProtoTorch datasets"""
|
"""ProtoTorch datasets"""
|
||||||
|
|
||||||
from .abstract import NumpyDataset
|
from .abstract import CSVDataset, NumpyDataset
|
||||||
from .sklearn import (
|
from .sklearn import (
|
||||||
Blobs,
|
Blobs,
|
||||||
Circles,
|
Circles,
|
||||||
@ -10,3 +10,4 @@ from .sklearn import (
|
|||||||
)
|
)
|
||||||
from .spiral import Spiral
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
from .xor import XOR
|
||||||
|
@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +20,7 @@ class Dataset(torch.utils.data.Dataset):
|
|||||||
_repr_indent = 2
|
_repr_indent = 2
|
||||||
|
|
||||||
def __init__(self, root):
|
def __init__(self, root):
|
||||||
if isinstance(root, torch._six.string_classes):
|
if isinstance(root, str):
|
||||||
root = os.path.expanduser(root)
|
root = os.path.expanduser(root)
|
||||||
self.root = root
|
self.root = root
|
||||||
|
|
||||||
@ -92,8 +93,23 @@ class ProtoDataset(Dataset):
|
|||||||
|
|
||||||
class NumpyDataset(torch.utils.data.TensorDataset):
|
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||||
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||||
|
|
||||||
def __init__(self, data, targets):
|
def __init__(self, data, targets):
|
||||||
self.data = torch.Tensor(data)
|
self.data = torch.Tensor(data)
|
||||||
self.targets = torch.LongTensor(targets)
|
self.targets = torch.LongTensor(targets)
|
||||||
tensors = [self.data, self.targets]
|
tensors = [self.data, self.targets]
|
||||||
super().__init__(*tensors)
|
super().__init__(*tensors)
|
||||||
|
|
||||||
|
|
||||||
|
class CSVDataset(NumpyDataset):
|
||||||
|
"""Create a Dataset from a CSV file."""
|
||||||
|
|
||||||
|
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
|
||||||
|
raw = np.genfromtxt(
|
||||||
|
filepath,
|
||||||
|
delimiter=delimiter,
|
||||||
|
skip_header=skip_header,
|
||||||
|
)
|
||||||
|
data = np.delete(raw, 1, target_col)
|
||||||
|
targets = raw[:, target_col]
|
||||||
|
super().__init__(data, targets)
|
||||||
|
@ -5,11 +5,18 @@ URL:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
from __future__ import annotations
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
import warnings
|
||||||
make_classification, make_moons)
|
from typing import Sequence
|
||||||
|
|
||||||
|
from sklearn.datasets import (
|
||||||
|
load_iris,
|
||||||
|
make_blobs,
|
||||||
|
make_circles,
|
||||||
|
make_classification,
|
||||||
|
make_moons,
|
||||||
|
)
|
||||||
|
|
||||||
from prototorch.datasets.abstract import NumpyDataset
|
from prototorch.datasets.abstract import NumpyDataset
|
||||||
|
|
||||||
@ -35,9 +42,10 @@ class Iris(NumpyDataset):
|
|||||||
|
|
||||||
:param dims: select a subset of dimensions
|
:param dims: select a subset of dimensions
|
||||||
"""
|
"""
|
||||||
def __init__(self, dims: Sequence[int] = None):
|
|
||||||
|
def __init__(self, dims: Sequence[int] | None = None):
|
||||||
x, y = load_iris(return_X_y=True)
|
x, y = load_iris(return_X_y=True)
|
||||||
if dims:
|
if dims is not None:
|
||||||
x = x[:, dims]
|
x = x[:, dims]
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
@ -49,15 +57,20 @@ class Blobs(NumpyDataset):
|
|||||||
https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.
|
https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
num_samples: int = 300,
|
def __init__(
|
||||||
num_features: int = 2,
|
self,
|
||||||
seed: Union[None, int] = 0):
|
num_samples: int = 300,
|
||||||
x, y = make_blobs(num_samples,
|
num_features: int = 2,
|
||||||
num_features,
|
seed: None | int = 0,
|
||||||
centers=None,
|
):
|
||||||
random_state=seed,
|
x, y = make_blobs(
|
||||||
shuffle=False)
|
num_samples,
|
||||||
|
num_features,
|
||||||
|
centers=None,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@ -69,29 +82,34 @@ class Random(NumpyDataset):
|
|||||||
|
|
||||||
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
num_samples: int = 300,
|
def __init__(
|
||||||
num_features: int = 2,
|
self,
|
||||||
num_classes: int = 2,
|
num_samples: int = 300,
|
||||||
num_clusters: int = 2,
|
num_features: int = 2,
|
||||||
num_informative: Union[None, int] = None,
|
num_classes: int = 2,
|
||||||
separation: float = 1.0,
|
num_clusters: int = 2,
|
||||||
seed: Union[None, int] = 0):
|
num_informative: None | int = None,
|
||||||
|
separation: float = 1.0,
|
||||||
|
seed: None | int = 0,
|
||||||
|
):
|
||||||
if not num_informative:
|
if not num_informative:
|
||||||
import math
|
import math
|
||||||
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
num_informative = math.ceil(math.log2(num_classes * num_clusters))
|
||||||
if num_features < num_informative:
|
if num_features < num_informative:
|
||||||
warnings.warn("Generating more features than requested.")
|
warnings.warn("Generating more features than requested.")
|
||||||
num_features = num_informative
|
num_features = num_informative
|
||||||
x, y = make_classification(num_samples,
|
x, y = make_classification(
|
||||||
num_features,
|
num_samples,
|
||||||
n_informative=num_informative,
|
num_features,
|
||||||
n_redundant=0,
|
n_informative=num_informative,
|
||||||
n_classes=num_classes,
|
n_redundant=0,
|
||||||
n_clusters_per_class=num_clusters,
|
n_classes=num_classes,
|
||||||
class_sep=separation,
|
n_clusters_per_class=num_clusters,
|
||||||
random_state=seed,
|
class_sep=separation,
|
||||||
shuffle=False)
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@ -104,16 +122,21 @@ class Circles(NumpyDataset):
|
|||||||
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html
|
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
num_samples: int = 300,
|
def __init__(
|
||||||
noise: float = 0.3,
|
self,
|
||||||
factor: float = 0.8,
|
num_samples: int = 300,
|
||||||
seed: Union[None, int] = 0):
|
noise: float = 0.3,
|
||||||
x, y = make_circles(num_samples,
|
factor: float = 0.8,
|
||||||
noise=noise,
|
seed: None | int = 0,
|
||||||
factor=factor,
|
):
|
||||||
random_state=seed,
|
x, y = make_circles(
|
||||||
shuffle=False)
|
num_samples,
|
||||||
|
noise=noise,
|
||||||
|
factor=factor,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
@ -126,12 +149,17 @@ class Moons(NumpyDataset):
|
|||||||
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html
|
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
num_samples: int = 300,
|
def __init__(
|
||||||
noise: float = 0.3,
|
self,
|
||||||
seed: Union[None, int] = 0):
|
num_samples: int = 300,
|
||||||
x, y = make_moons(num_samples,
|
noise: float = 0.3,
|
||||||
noise=noise,
|
seed: None | int = 0,
|
||||||
random_state=seed,
|
):
|
||||||
shuffle=False)
|
x, y = make_moons(
|
||||||
|
num_samples,
|
||||||
|
noise=noise,
|
||||||
|
random_state=seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
super().__init__(x, y)
|
super().__init__(x, y)
|
||||||
|
@ -9,6 +9,7 @@ def make_spiral(num_samples=500, noise=0.3):
|
|||||||
|
|
||||||
For use in Prototorch use `prototorch.datasets.Spiral` instead.
|
For use in Prototorch use `prototorch.datasets.Spiral` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_samples(n, delta_t):
|
def get_samples(n, delta_t):
|
||||||
points = []
|
points = []
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
@ -52,6 +53,7 @@ class Spiral(torch.utils.data.TensorDataset):
|
|||||||
:param num_samples: number of random samples
|
:param num_samples: number of random samples
|
||||||
:param noise: noise added to the spirals
|
:param noise: noise added to the spirals
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_samples: int = 500, noise: float = 0.3):
|
def __init__(self, num_samples: int = 500, noise: float = 0.3):
|
||||||
x, y = make_spiral(num_samples, noise)
|
x, y = make_spiral(num_samples, noise)
|
||||||
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
||||||
|
@ -36,6 +36,7 @@ Description:
|
|||||||
are determined by analytic chemistry.
|
are determined by analytic chemistry.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -81,13 +82,11 @@ class Tecator(ProtoDataset):
|
|||||||
if self._check_exists():
|
if self._check_exists():
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Making directories...")
|
||||||
print("Making directories...")
|
|
||||||
os.makedirs(self.raw_folder, exist_ok=True)
|
os.makedirs(self.raw_folder, exist_ok=True)
|
||||||
os.makedirs(self.processed_folder, exist_ok=True)
|
os.makedirs(self.processed_folder, exist_ok=True)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Downloading...")
|
||||||
print("Downloading...")
|
|
||||||
for fileid, md5 in self._resources:
|
for fileid, md5 in self._resources:
|
||||||
filename = "tecator.npz"
|
filename = "tecator.npz"
|
||||||
download_file_from_google_drive(fileid,
|
download_file_from_google_drive(fileid,
|
||||||
@ -95,8 +94,7 @@ class Tecator(ProtoDataset):
|
|||||||
filename=filename,
|
filename=filename,
|
||||||
md5=md5)
|
md5=md5)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Processing...")
|
||||||
print("Processing...")
|
|
||||||
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
with np.load(os.path.join(self.raw_folder, "tecator.npz"),
|
||||||
allow_pickle=False) as f:
|
allow_pickle=False) as f:
|
||||||
x_train, y_train = f["x_train"], f["y_train"]
|
x_train, y_train = f["x_train"], f["y_train"]
|
||||||
@ -117,5 +115,4 @@ class Tecator(ProtoDataset):
|
|||||||
"wb") as f:
|
"wb") as f:
|
||||||
torch.save(test_set, f)
|
torch.save(test_set, f)
|
||||||
|
|
||||||
if self.verbose:
|
logging.debug("Done!")
|
||||||
print("Done!")
|
|
||||||
|
19
prototorch/datasets/xor.py
Normal file
19
prototorch/datasets/xor.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_xor(num_samples=500):
|
||||||
|
x = torch.rand(num_samples, 2)
|
||||||
|
y = torch.zeros(num_samples)
|
||||||
|
y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1
|
||||||
|
y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class XOR(torch.utils.data.TensorDataset):
|
||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
|
||||||
|
def __init__(self, num_samples: int = 500):
|
||||||
|
x, y = make_xor(num_samples)
|
||||||
|
super().__init__(x, y)
|
@ -4,6 +4,7 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class LambdaLayer(torch.nn.Module):
|
class LambdaLayer(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, fn, name=None):
|
def __init__(self, fn, name=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
@ -17,6 +18,7 @@ class LambdaLayer(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LossLayer(torch.nn.modules.loss._Loss):
|
class LossLayer(torch.nn.modules.loss._Loss):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
fn,
|
fn,
|
||||||
name=None,
|
name=None,
|
||||||
|
@ -1,6 +1,11 @@
|
|||||||
"""ProtoFlow utils module"""
|
"""ProtoTorch utils module"""
|
||||||
|
|
||||||
from .colors import hex_to_rgb, rgb_to_hex
|
from .colors import (
|
||||||
|
get_colors,
|
||||||
|
get_legend_handles,
|
||||||
|
hex_to_rgb,
|
||||||
|
rgb_to_hex,
|
||||||
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
mesh2d,
|
mesh2d,
|
||||||
parse_data_arg,
|
parse_data_arg,
|
||||||
|
@ -1,4 +1,13 @@
|
|||||||
"""ProtoFlow color utilities"""
|
"""ProtoTorch color utilities"""
|
||||||
|
|
||||||
|
import matplotlib.lines as mlines
|
||||||
|
import torch
|
||||||
|
from matplotlib import cm
|
||||||
|
from matplotlib.colors import (
|
||||||
|
Normalize,
|
||||||
|
to_hex,
|
||||||
|
to_rgb,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def hex_to_rgb(hex_values):
|
def hex_to_rgb(hex_values):
|
||||||
@ -13,3 +22,39 @@ def rgb_to_hex(rgb_values):
|
|||||||
for v in rgb_values:
|
for v in rgb_values:
|
||||||
c = "%02x%02x%02x" % tuple(v)
|
c = "%02x%02x%02x" % tuple(v)
|
||||||
yield c
|
yield c
|
||||||
|
|
||||||
|
|
||||||
|
def get_colors(vmax, vmin=0, cmap="viridis"):
|
||||||
|
cmap = cm.get_cmap(cmap)
|
||||||
|
colornorm = Normalize(vmin=vmin, vmax=vmax)
|
||||||
|
colors = dict()
|
||||||
|
for c in range(vmin, vmax + 1):
|
||||||
|
colors[c] = to_hex(cmap(colornorm(c)))
|
||||||
|
return colors
|
||||||
|
|
||||||
|
|
||||||
|
def get_legend_handles(colors, labels, marker="dots", zero_indexed=False):
|
||||||
|
handles = list()
|
||||||
|
for color, label in zip(colors.values(), labels):
|
||||||
|
if marker == "dots":
|
||||||
|
handle = mlines.Line2D(
|
||||||
|
xdata=[],
|
||||||
|
ydata=[],
|
||||||
|
label=label,
|
||||||
|
color="white",
|
||||||
|
markerfacecolor=color,
|
||||||
|
marker="o",
|
||||||
|
markersize=10,
|
||||||
|
markeredgecolor="k",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
handle = mlines.Line2D(
|
||||||
|
xdata=[],
|
||||||
|
ydata=[],
|
||||||
|
label=label,
|
||||||
|
color=color,
|
||||||
|
marker="",
|
||||||
|
markersize=15,
|
||||||
|
)
|
||||||
|
handles.append(handle)
|
||||||
|
return handles
|
||||||
|
@ -1,14 +1,45 @@
|
|||||||
"""ProtoFlow utilities"""
|
"""ProtoTorch utilities"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Iterable
|
from typing import (
|
||||||
from typing import Union
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
|
def generate_mesh(
|
||||||
|
minima: torch.TensorType,
|
||||||
|
maxima: torch.TensorType,
|
||||||
|
border: float = 1.0,
|
||||||
|
resolution: int = 100,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
):
|
||||||
|
# Apply Border
|
||||||
|
ptp = maxima - minima
|
||||||
|
shift = border * ptp
|
||||||
|
minima -= shift
|
||||||
|
maxima += shift
|
||||||
|
|
||||||
|
# Generate Mesh
|
||||||
|
minima = minima.to(device).unsqueeze(1)
|
||||||
|
maxima = maxima.to(device).unsqueeze(1)
|
||||||
|
|
||||||
|
factors = torch.linspace(0, 1, resolution, device=device)
|
||||||
|
marginals = factors * maxima + ((1 - factors) * minima)
|
||||||
|
|
||||||
|
single_dimensions = torch.meshgrid(*marginals)
|
||||||
|
mesh_input = torch.stack([dim.ravel() for dim in single_dimensions], dim=1)
|
||||||
|
|
||||||
|
return mesh_input, single_dimensions
|
||||||
|
|
||||||
|
|
||||||
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||||
if x is not None:
|
if x is not None:
|
||||||
x_shift = border * np.ptp(x[:, 0])
|
x_shift = border * np.ptp(x[:, 0])
|
||||||
@ -24,15 +55,16 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
|||||||
return mesh, xx, yy
|
return mesh, xx, yy
|
||||||
|
|
||||||
|
|
||||||
def distribution_from_list(list_dist: list[int],
|
def distribution_from_list(list_dist: List[int],
|
||||||
clabels: Iterable[int] = None):
|
clabels: Optional[Iterable[int]] = None):
|
||||||
clabels = clabels or list(range(len(list_dist)))
|
clabels = clabels or list(range(len(list_dist)))
|
||||||
distribution = dict(zip(clabels, list_dist))
|
distribution = dict(zip(clabels, list_dist))
|
||||||
return distribution
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
def parse_distribution(user_distribution,
|
def parse_distribution(
|
||||||
clabels: Iterable[int] = None) -> dict[int, int]:
|
user_distribution,
|
||||||
|
clabels: Optional[Iterable[int]] = None) -> Dict[int, int]:
|
||||||
"""Parse user-provided distribution.
|
"""Parse user-provided distribution.
|
||||||
|
|
||||||
Return a dictionary with integer keys that represent the class labels and
|
Return a dictionary with integer keys that represent the class labels and
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
[pylint]
|
[pylint]
|
||||||
disable =
|
disable =
|
||||||
too-many-arguments,
|
too-many-arguments,
|
||||||
too-few-public-methods,
|
too-few-public-methods,
|
||||||
fixme,
|
fixme,
|
||||||
|
|
||||||
|
|
||||||
[pycodestyle]
|
[pycodestyle]
|
||||||
max-line-length = 79
|
max-line-length = 79
|
||||||
@ -12,4 +13,4 @@ multi_line_output = 3
|
|||||||
include_trailing_comma = True
|
include_trailing_comma = True
|
||||||
force_grid_wrap = 3
|
force_grid_wrap = 3
|
||||||
use_parentheses = True
|
use_parentheses = True
|
||||||
line_length = 79
|
line_length = 79
|
||||||
|
37
setup.py
37
setup.py
@ -15,21 +15,22 @@ from setuptools import find_packages, setup
|
|||||||
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
PROJECT_URL = "https://github.com/si-cim/prototorch"
|
||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
|
||||||
|
|
||||||
with open("README.md", "r") as fh:
|
with open("README.md", encoding="utf-8") as fh:
|
||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"torch>=1.3.1",
|
"torch>=2.0.0",
|
||||||
"torchvision>=0.6.0",
|
"torchvision",
|
||||||
"numpy>=1.9.1",
|
"numpy",
|
||||||
"sklearn",
|
"scikit-learn",
|
||||||
|
"matplotlib",
|
||||||
]
|
]
|
||||||
DATASETS = [
|
DATASETS = [
|
||||||
"requests",
|
"requests",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
]
|
]
|
||||||
DEV = [
|
DEV = [
|
||||||
"bumpversion",
|
"bump2version",
|
||||||
"pre-commit",
|
"pre-commit",
|
||||||
]
|
]
|
||||||
DOCS = [
|
DOCS = [
|
||||||
@ -40,15 +41,17 @@ DOCS = [
|
|||||||
"sphinx-autodoc-typehints",
|
"sphinx-autodoc-typehints",
|
||||||
]
|
]
|
||||||
EXAMPLES = [
|
EXAMPLES = [
|
||||||
"matplotlib",
|
|
||||||
"torchinfo",
|
"torchinfo",
|
||||||
]
|
]
|
||||||
TESTS = ["codecov", "pytest"]
|
TESTS = [
|
||||||
|
"flake8",
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="prototorch",
|
name="prototorch",
|
||||||
version="0.6.0",
|
version="0.7.6",
|
||||||
description="Highly extensible, GPU-supported "
|
description="Highly extensible, GPU-supported "
|
||||||
"Learning Vector Quantization (LVQ) toolbox "
|
"Learning Vector Quantization (LVQ) toolbox "
|
||||||
"built using PyTorch and its nn API.",
|
"built using PyTorch and its nn API.",
|
||||||
@ -59,7 +62,7 @@ setup(
|
|||||||
url=PROJECT_URL,
|
url=PROJECT_URL,
|
||||||
download_url=DOWNLOAD_URL,
|
download_url=DOWNLOAD_URL,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
python_requires=">=3.9",
|
python_requires=">=3.8",
|
||||||
install_requires=INSTALL_REQUIRES,
|
install_requires=INSTALL_REQUIRES,
|
||||||
extras_require={
|
extras_require={
|
||||||
"datasets": DATASETS,
|
"datasets": DATASETS,
|
||||||
@ -70,18 +73,22 @@ setup(
|
|||||||
"all": ALL,
|
"all": ALL,
|
||||||
},
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 2 - Pre-Alpha",
|
|
||||||
"Environment :: Console",
|
"Environment :: Console",
|
||||||
|
"Natural Language :: English",
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Intended Audience :: Education",
|
"Intended Audience :: Education",
|
||||||
"Intended Audience :: Science/Research",
|
"Intended Audience :: Science/Research",
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Natural Language :: English",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
],
|
],
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
|
@ -245,33 +245,45 @@ def test_random_reasonings_init_channels_not_first():
|
|||||||
|
|
||||||
# Transform initializers
|
# Transform initializers
|
||||||
def test_eye_transform_init_square():
|
def test_eye_transform_init_square():
|
||||||
t = pt.initializers.EyeTransformInitializer()
|
t = pt.initializers.EyeLinearTransformInitializer()
|
||||||
I = t.generate(3, 3)
|
I = t.generate(3, 3)
|
||||||
assert torch.allclose(I, torch.eye(3))
|
assert torch.allclose(I, torch.eye(3))
|
||||||
|
|
||||||
|
|
||||||
def test_eye_transform_init_narrow():
|
def test_eye_transform_init_narrow():
|
||||||
t = pt.initializers.EyeTransformInitializer()
|
t = pt.initializers.EyeLinearTransformInitializer()
|
||||||
actual = t.generate(3, 2)
|
actual = t.generate(3, 2)
|
||||||
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
def test_eye_transform_init_wide():
|
def test_eye_transform_init_wide():
|
||||||
t = pt.initializers.EyeTransformInitializer()
|
t = pt.initializers.EyeLinearTransformInitializer()
|
||||||
actual = t.generate(2, 3)
|
actual = t.generate(2, 3)
|
||||||
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
def test_linear_transform():
|
def test_linear_transform_default_eye_init():
|
||||||
l = pt.transforms.LinearTransform(2, 4)
|
l = pt.transforms.LinearTransform(2, 4)
|
||||||
actual = l.weights
|
actual = l.weights
|
||||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_transform_forward():
|
||||||
|
l = pt.transforms.LinearTransform(4, 2)
|
||||||
|
actual_weights = l.weights
|
||||||
|
desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
|
||||||
|
assert torch.allclose(actual_weights, desired_weights)
|
||||||
|
actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[5.5, 6.6, 7.7, 8.8]]))
|
||||||
|
desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
|
||||||
|
assert torch.allclose(actual_outputs, desired_outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transform_zeros_init():
|
def test_linear_transform_zeros_init():
|
||||||
l = pt.transforms.LinearTransform(
|
l = pt.transforms.LinearTransform(
|
||||||
in_dim=2,
|
in_dim=2,
|
||||||
@ -392,6 +404,7 @@ def test_glvq_loss_one_hot_unequal():
|
|||||||
|
|
||||||
# Activations
|
# Activations
|
||||||
class TestActivations(unittest.TestCase):
|
class TestActivations(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.flist = ["identity", "sigmoid_beta", "swish_beta"]
|
self.flist = ["identity", "sigmoid_beta", "swish_beta"]
|
||||||
self.x = torch.randn(1024, 1)
|
self.x = torch.randn(1024, 1)
|
||||||
@ -406,6 +419,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
self.assertTrue(iscallable)
|
self.assertTrue(iscallable)
|
||||||
|
|
||||||
def test_callable_deserialization(self):
|
def test_callable_deserialization(self):
|
||||||
|
|
||||||
def dummy(x, **kwargs):
|
def dummy(x, **kwargs):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -450,6 +464,7 @@ class TestActivations(unittest.TestCase):
|
|||||||
|
|
||||||
# Competitions
|
# Competitions
|
||||||
class TestCompetitions(unittest.TestCase):
|
class TestCompetitions(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -503,6 +518,7 @@ class TestCompetitions(unittest.TestCase):
|
|||||||
|
|
||||||
# Pooling
|
# Pooling
|
||||||
class TestPooling(unittest.TestCase):
|
class TestPooling(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -603,6 +619,7 @@ class TestPooling(unittest.TestCase):
|
|||||||
|
|
||||||
# Distances
|
# Distances
|
||||||
class TestDistances(unittest.TestCase):
|
class TestDistances(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.nx, self.mx = 32, 2048
|
self.nx, self.mx = 32, 2048
|
||||||
self.ny, self.my = 8, 2048
|
self.ny, self.my = 8, 2048
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""ProtoTorch datasets test suite"""
|
"""ProtoTorch datasets test suite"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -12,6 +11,7 @@ from prototorch.datasets.abstract import Dataset, ProtoDataset
|
|||||||
|
|
||||||
|
|
||||||
class TestAbstract(unittest.TestCase):
|
class TestAbstract(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.ds = Dataset("./artifacts")
|
self.ds = Dataset("./artifacts")
|
||||||
|
|
||||||
@ -28,6 +28,7 @@ class TestAbstract(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TestProtoDataset(unittest.TestCase):
|
class TestProtoDataset(unittest.TestCase):
|
||||||
|
|
||||||
def test_download(self):
|
def test_download(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
_ = ProtoDataset("./artifacts", download=True)
|
_ = ProtoDataset("./artifacts", download=True)
|
||||||
@ -38,6 +39,7 @@ class TestProtoDataset(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TestNumpyDataset(unittest.TestCase):
|
class TestNumpyDataset(unittest.TestCase):
|
||||||
|
|
||||||
def test_list_init(self):
|
def test_list_init(self):
|
||||||
ds = pt.datasets.NumpyDataset([1], [1])
|
ds = pt.datasets.NumpyDataset([1], [1])
|
||||||
self.assertEqual(len(ds), 1)
|
self.assertEqual(len(ds), 1)
|
||||||
@ -49,13 +51,33 @@ class TestNumpyDataset(unittest.TestCase):
|
|||||||
self.assertEqual(len(ds), 3)
|
self.assertEqual(len(ds), 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCSVDataset(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
data = np.random.rand(100, 4)
|
||||||
|
targets = np.random.randint(2, size=(100, 1))
|
||||||
|
arr = np.hstack([data, targets])
|
||||||
|
if not os.path.exists("./artifacts"):
|
||||||
|
os.mkdir("./artifacts")
|
||||||
|
np.savetxt("./artifacts/test.csv", arr, delimiter=",")
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
ds = pt.datasets.CSVDataset("./artifacts/test.csv")
|
||||||
|
self.assertEqual(len(ds), 100)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
os.remove("./artifacts/test.csv")
|
||||||
|
|
||||||
|
|
||||||
class TestSpiral(unittest.TestCase):
|
class TestSpiral(unittest.TestCase):
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
ds = pt.datasets.Spiral(num_samples=10)
|
ds = pt.datasets.Spiral(num_samples=10)
|
||||||
self.assertEqual(len(ds), 10)
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
class TestIris(unittest.TestCase):
|
class TestIris(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.ds = pt.datasets.Iris()
|
self.ds = pt.datasets.Iris()
|
||||||
|
|
||||||
@ -71,24 +93,28 @@ class TestIris(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class TestBlobs(unittest.TestCase):
|
class TestBlobs(unittest.TestCase):
|
||||||
|
|
||||||
def test_size(self):
|
def test_size(self):
|
||||||
ds = pt.datasets.Blobs(num_samples=10)
|
ds = pt.datasets.Blobs(num_samples=10)
|
||||||
self.assertEqual(len(ds), 10)
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
class TestRandom(unittest.TestCase):
|
class TestRandom(unittest.TestCase):
|
||||||
|
|
||||||
def test_size(self):
|
def test_size(self):
|
||||||
ds = pt.datasets.Random(num_samples=10)
|
ds = pt.datasets.Random(num_samples=10)
|
||||||
self.assertEqual(len(ds), 10)
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
class TestCircles(unittest.TestCase):
|
class TestCircles(unittest.TestCase):
|
||||||
|
|
||||||
def test_size(self):
|
def test_size(self):
|
||||||
ds = pt.datasets.Circles(num_samples=10)
|
ds = pt.datasets.Circles(num_samples=10)
|
||||||
self.assertEqual(len(ds), 10)
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
class TestMoons(unittest.TestCase):
|
class TestMoons(unittest.TestCase):
|
||||||
|
|
||||||
def test_size(self):
|
def test_size(self):
|
||||||
ds = pt.datasets.Moons(num_samples=10)
|
ds = pt.datasets.Moons(num_samples=10)
|
||||||
self.assertEqual(len(ds), 10)
|
self.assertEqual(len(ds), 10)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user