Compare commits
27 Commits
v0.6.0
...
feature/je
Author | SHA1 | Date | |
---|---|---|---|
|
0788718c31 | ||
|
4f5c4ebe8f | ||
|
ae2a8e54ef | ||
|
d9be100c1f | ||
|
9d1dc7320f | ||
|
d11ab71b7e | ||
|
59037e1a50 | ||
|
a19b99be82 | ||
|
f7e7558338 | ||
|
d57648f9d6 | ||
|
d24f580bf0 | ||
|
916973c3e8 | ||
|
b49b7a2d41 | ||
|
b6e8242383 | ||
|
cd616d11b9 | ||
|
afcfcb8973 | ||
|
bf03a45475 | ||
|
083b5c1597 | ||
|
7f0a8e9bce | ||
|
bf09ff8f7f | ||
|
c1d7cfee8f | ||
|
99be965581 | ||
|
fdb9a7c66d | ||
|
eb79b703d8 | ||
|
bc9a826b7d | ||
|
cfe09ec06b | ||
|
3d76dffe3c |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.6.0
|
current_version = 0.7.1
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||||
|
7
.ci/python310.Dockerfile
Normal file
7
.ci/python310.Dockerfile
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
FROM python:3.9
|
||||||
|
|
||||||
|
RUN adduser --uid 1000 jenkins
|
||||||
|
|
||||||
|
USER jenkins
|
||||||
|
|
||||||
|
RUN mkdir -p /home/jenkins/.cache/pip
|
7
.ci/python36.Dockerfile
Normal file
7
.ci/python36.Dockerfile
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
FROM python:3.6
|
||||||
|
|
||||||
|
RUN adduser --uid 1000 jenkins
|
||||||
|
|
||||||
|
USER jenkins
|
||||||
|
|
||||||
|
RUN mkdir -p /home/jenkins/.cache/pip
|
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -10,21 +10,28 @@ assignees: ''
|
|||||||
**Describe the bug**
|
**Describe the bug**
|
||||||
A clear and concise description of what the bug is.
|
A clear and concise description of what the bug is.
|
||||||
|
|
||||||
**To Reproduce**
|
**Steps to reproduce the behavior**
|
||||||
Steps to reproduce the behavior:
|
1. ...
|
||||||
1. Install Prototorch by running '...'
|
2. Run script '...' or this snippet:
|
||||||
2. Run script '...'
|
```python
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
...
|
||||||
|
```
|
||||||
3. See errors
|
3. See errors
|
||||||
|
|
||||||
**Expected behavior**
|
**Expected behavior**
|
||||||
A clear and concise description of what you expected to happen.
|
A clear and concise description of what you expected to happen.
|
||||||
|
|
||||||
|
**Observed behavior**
|
||||||
|
A clear and concise description of what actually happened.
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
|
|
||||||
**Desktop (please complete the following information):**
|
**System and version information**
|
||||||
- OS: [e.g. Ubuntu 20.10]
|
- OS: [e.g. Ubuntu 20.10]
|
||||||
- Prototorch Version: [e.g. v0.4.0]
|
- ProtoTorch Version: [e.g. 0.4.0]
|
||||||
- Python Version: [e.g. 3.9.5]
|
- Python Version: [e.g. 3.9.5]
|
||||||
|
|
||||||
**Additional context**
|
**Additional context**
|
||||||
|
36
.travis.yml
36
.travis.yml
@@ -1,36 +0,0 @@
|
|||||||
dist: bionic
|
|
||||||
sudo: false
|
|
||||||
language: python
|
|
||||||
python: 3.9
|
|
||||||
cache:
|
|
||||||
directories:
|
|
||||||
- "$HOME/.cache/pip"
|
|
||||||
- "./tests/artifacts"
|
|
||||||
- "$HOME/datasets"
|
|
||||||
install:
|
|
||||||
- pip install .[all] --progress-bar off
|
|
||||||
|
|
||||||
# Generate code coverage report
|
|
||||||
script:
|
|
||||||
- coverage run -m pytest
|
|
||||||
|
|
||||||
# Push the results to codecov
|
|
||||||
after_success:
|
|
||||||
- bash <(curl -s https://codecov.io/bash)
|
|
||||||
|
|
||||||
# Publish on PyPI
|
|
||||||
deploy:
|
|
||||||
provider: pypi
|
|
||||||
username: __token__
|
|
||||||
password:
|
|
||||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
|
||||||
on:
|
|
||||||
tags: true
|
|
||||||
skip_existing: true
|
|
||||||
|
|
||||||
# The password is encrypted with:
|
|
||||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
|
||||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
|
||||||
# https://github.com/travis-ci/travis.rb#installation
|
|
||||||
# for more details
|
|
||||||
# Note: The encrypt command does not work well in ZSH.
|
|
41
Jenkinsfile
vendored
Normal file
41
Jenkinsfile
vendored
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
pipeline {
|
||||||
|
agent none
|
||||||
|
stages {
|
||||||
|
stage('Unit Tests') {
|
||||||
|
parallel {
|
||||||
|
stage('3.6'){
|
||||||
|
agent{
|
||||||
|
dockerfile {
|
||||||
|
filename 'python36.Dockerfile'
|
||||||
|
dir '.ci'
|
||||||
|
args '-v pip-cache:/home/jenkins/.cache/pip'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
steps {
|
||||||
|
sh 'pip install pip --upgrade --progress-bar off'
|
||||||
|
sh 'pip install .[all] --progress-bar off'
|
||||||
|
sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
|
||||||
|
cobertura coberturaReportFile: 'reports/coverage.xml'
|
||||||
|
junit 'reports/**/*.xml'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stage('3.10'){
|
||||||
|
agent{
|
||||||
|
dockerfile {
|
||||||
|
filename 'python310.Dockerfile'
|
||||||
|
dir '.ci'
|
||||||
|
args '-v pip-cache:/home/jenkins/.cache/pip'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
steps {
|
||||||
|
sh 'pip install pip --upgrade --progress-bar off'
|
||||||
|
sh 'pip install .[all] --progress-bar off'
|
||||||
|
sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
|
||||||
|
cobertura coberturaReportFile: 'reports/coverage.xml'
|
||||||
|
junit 'reports/**/*.xml'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
46
deprecated.travis.yml
Normal file
46
deprecated.travis.yml
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
dist: bionic
|
||||||
|
sudo: false
|
||||||
|
language: python
|
||||||
|
python:
|
||||||
|
- 3.9
|
||||||
|
- 3.8
|
||||||
|
- 3.7
|
||||||
|
- 3.6
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- "$HOME/.cache/pip"
|
||||||
|
- "./tests/artifacts"
|
||||||
|
- "$HOME/datasets"
|
||||||
|
install:
|
||||||
|
- pip install .[all] --progress-bar off
|
||||||
|
|
||||||
|
# Generate code coverage report
|
||||||
|
script:
|
||||||
|
- coverage run -m pytest
|
||||||
|
|
||||||
|
# Push the results to codecov
|
||||||
|
after_success:
|
||||||
|
- bash <(curl -s https://codecov.io/bash)
|
||||||
|
|
||||||
|
# Publish on PyPI
|
||||||
|
jobs:
|
||||||
|
include:
|
||||||
|
- stage: build
|
||||||
|
python: 3.9
|
||||||
|
script: echo "Starting Pypi build"
|
||||||
|
deploy:
|
||||||
|
provider: pypi
|
||||||
|
username: __token__
|
||||||
|
distributions: "sdist bdist_wheel"
|
||||||
|
password:
|
||||||
|
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
||||||
|
on:
|
||||||
|
tags: true
|
||||||
|
skip_existing: true
|
||||||
|
|
||||||
|
# The password is encrypted with:
|
||||||
|
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||||
|
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||||
|
# https://github.com/travis-ci/travis.rb#installation
|
||||||
|
# for more details
|
||||||
|
# Note: The encrypt command does not work well in ZSH.
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
|||||||
|
|
||||||
# The full version, including alpha/beta/rc tags
|
# The full version, including alpha/beta/rc tags
|
||||||
#
|
#
|
||||||
release = "0.6.0"
|
release = "0.7.1"
|
||||||
|
|
||||||
# -- General configuration ---------------------------------------------------
|
# -- General configuration ---------------------------------------------------
|
||||||
|
|
||||||
|
@@ -22,7 +22,7 @@ from .core import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
__version__ = "0.6.0"
|
__version__ = "0.7.1"
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
"competitions",
|
"competitions",
|
||||||
|
@@ -48,7 +48,7 @@ class WTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, distances, labels):
|
def forward(self, distances, labels): # pylint: disable=no-self-use
|
||||||
return wtac(distances, labels)
|
return wtac(distances, labels)
|
||||||
|
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ class LTAC(torch.nn.Module):
|
|||||||
Thin wrapper over the `wtac` function.
|
Thin wrapper over the `wtac` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, probs, labels):
|
def forward(self, probs, labels): # pylint: disable=no-self-use
|
||||||
return wtac(-1.0 * probs, labels)
|
return wtac(-1.0 * probs, labels)
|
||||||
|
|
||||||
|
|
||||||
@@ -85,5 +85,5 @@ class CBCC(torch.nn.Module):
|
|||||||
Thin wrapper over the `cbcc` function.
|
Thin wrapper over the `cbcc` function.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def forward(self, detections, reasonings):
|
def forward(self, detections, reasonings): # pylint: disable=no-self-use
|
||||||
return cbcc(detections, reasonings)
|
return cbcc(detections, reasonings)
|
||||||
|
@@ -253,8 +253,10 @@ class Reasonings(torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
distribution: Union[dict, list, tuple],
|
distribution: Union[dict, list, tuple],
|
||||||
initializer:
|
initializer:
|
||||||
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
|
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.add_reasonings(distribution, initializer)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
|
@@ -41,9 +41,6 @@ def euclidean_distance_v2(x, y):
|
|||||||
# batch diagonal. See:
|
# batch diagonal. See:
|
||||||
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||||
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||||
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
|
||||||
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
|
||||||
# print(f"{distances.shape=}") # (nx, ny)
|
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
|
@@ -303,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
|||||||
|
|
||||||
|
|
||||||
# Reasonings
|
# Reasonings
|
||||||
|
def compute_distribution_shape(distribution):
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
|
num_components = sum(distribution.values())
|
||||||
|
num_classes = len(distribution.keys())
|
||||||
|
return (num_components, num_classes, 2)
|
||||||
|
|
||||||
|
|
||||||
class AbstractReasoningsInitializer(ABC):
|
class AbstractReasoningsInitializer(ABC):
|
||||||
"""Abstract class for all reasonings initializers."""
|
"""Abstract class for all reasonings initializers."""
|
||||||
def __init__(self, components_first: bool = True):
|
def __init__(self, components_first: bool = True):
|
||||||
self.components_first = components_first
|
self.components_first = components_first
|
||||||
|
|
||||||
def compute_shape(self, distribution):
|
|
||||||
distribution = parse_distribution(distribution)
|
|
||||||
num_components = sum(distribution.values())
|
|
||||||
num_classes = len(distribution.keys())
|
|
||||||
return (num_components, num_classes, 2)
|
|
||||||
|
|
||||||
def generate_end_hook(self, reasonings):
|
def generate_end_hook(self, reasonings):
|
||||||
if not self.components_first:
|
if not self.components_first:
|
||||||
reasonings = reasonings.permute(2, 1, 0)
|
reasonings = reasonings.permute(2, 1, 0)
|
||||||
@@ -349,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with zeros."""
|
"""Reasonings are all initialized with zeros."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.zeros(*shape)
|
reasonings = torch.zeros(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@@ -358,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with ones."""
|
"""Reasonings are all initialized with ones."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape)
|
reasonings = torch.ones(*shape)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@@ -372,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
self.maximum = maximum
|
self.maximum = maximum
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = self.compute_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
@@ -381,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Each component reasons positively for exactly one class."""
|
"""Each component reasons positively for exactly one class."""
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
num_components, num_classes, _ = compute_distribution_shape(
|
||||||
|
distribution)
|
||||||
A = OneHotLabelsInitializer().generate(distribution)
|
A = OneHotLabelsInitializer().generate(distribution)
|
||||||
B = torch.zeros(num_components, num_classes)
|
B = torch.zeros(num_components, num_classes)
|
||||||
reasonings = torch.stack([A, B], dim=-1)
|
reasonings = torch.stack([A, B], dim=-1)
|
||||||
|
@@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3):
|
|||||||
|
|
||||||
|
|
||||||
class GLVQLoss(torch.nn.Module):
|
class GLVQLoss(torch.nn.Module):
|
||||||
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.squashing = get_activation(squashing)
|
self.transfer_fn = get_activation(transfer_fn)
|
||||||
self.beta = torch.tensor(beta)
|
self.beta = torch.tensor(beta)
|
||||||
|
|
||||||
def forward(self, outputs, targets):
|
def forward(self, outputs, targets, plabels):
|
||||||
distances, plabels = outputs
|
mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
||||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
|
||||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
return batch_loss.sum()
|
||||||
return torch.sum(batch_loss, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||||
|
@@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
|
|||||||
|
|
||||||
class StratifiedSumPooling(torch.nn.Module):
|
class StratifiedSumPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_sum_pooling(values, labels)
|
return stratified_sum_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedProdPooling(torch.nn.Module):
|
class StratifiedProdPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_prod_pooling(values, labels)
|
return stratified_prod_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMinPooling(torch.nn.Module):
|
class StratifiedMinPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_min_pooling(values, labels)
|
return stratified_min_pooling(values, labels)
|
||||||
|
|
||||||
|
|
||||||
class StratifiedMaxPooling(torch.nn.Module):
|
class StratifiedMaxPooling(torch.nn.Module):
|
||||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||||
def forward(self, values, labels):
|
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||||
return stratified_max_pooling(values, labels)
|
return stratified_max_pooling(values, labels)
|
||||||
|
@@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
|
|||||||
self._register_weights(weights)
|
self._register_weights(weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x @ self.weights.T
|
return x @ self.weights
|
||||||
|
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
"""ProtoTorch datasets"""
|
"""ProtoTorch datasets"""
|
||||||
|
|
||||||
from .abstract import NumpyDataset
|
from .abstract import CSVDataset, NumpyDataset
|
||||||
from .sklearn import (
|
from .sklearn import (
|
||||||
Blobs,
|
Blobs,
|
||||||
Circles,
|
Circles,
|
||||||
@@ -10,3 +10,4 @@ from .sklearn import (
|
|||||||
)
|
)
|
||||||
from .spiral import Spiral
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
from .xor import XOR
|
||||||
|
@@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
|
|||||||
self.targets = torch.LongTensor(targets)
|
self.targets = torch.LongTensor(targets)
|
||||||
tensors = [self.data, self.targets]
|
tensors = [self.data, self.targets]
|
||||||
super().__init__(*tensors)
|
super().__init__(*tensors)
|
||||||
|
|
||||||
|
|
||||||
|
class CSVDataset(NumpyDataset):
|
||||||
|
"""Create a Dataset from a CSV file."""
|
||||||
|
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
|
||||||
|
raw = np.genfromtxt(
|
||||||
|
filepath,
|
||||||
|
delimiter=delimiter,
|
||||||
|
skip_header=skip_header,
|
||||||
|
)
|
||||||
|
data = np.delete(raw, 1, target_col)
|
||||||
|
targets = raw[:, target_col]
|
||||||
|
super().__init__(data, targets)
|
||||||
|
18
prototorch/datasets/xor.py
Normal file
18
prototorch/datasets/xor.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def make_xor(num_samples=500):
|
||||||
|
x = torch.rand(num_samples, 2)
|
||||||
|
y = torch.zeros(num_samples)
|
||||||
|
y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1
|
||||||
|
y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
class XOR(torch.utils.data.TensorDataset):
|
||||||
|
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||||
|
def __init__(self, num_samples: int = 500):
|
||||||
|
x, y = make_xor(num_samples)
|
||||||
|
super().__init__(x, y)
|
@@ -1,8 +1,12 @@
|
|||||||
"""ProtoFlow utilities"""
|
"""ProtoFlow utilities"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Iterable
|
from typing import (
|
||||||
from typing import Union
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -24,7 +28,7 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
|||||||
return mesh, xx, yy
|
return mesh, xx, yy
|
||||||
|
|
||||||
|
|
||||||
def distribution_from_list(list_dist: list[int],
|
def distribution_from_list(list_dist: List[int],
|
||||||
clabels: Iterable[int] = None):
|
clabels: Iterable[int] = None):
|
||||||
clabels = clabels or list(range(len(list_dist)))
|
clabels = clabels or list(range(len(list_dist)))
|
||||||
distribution = dict(zip(clabels, list_dist))
|
distribution = dict(zip(clabels, list_dist))
|
||||||
@@ -32,7 +36,7 @@ def distribution_from_list(list_dist: list[int],
|
|||||||
|
|
||||||
|
|
||||||
def parse_distribution(user_distribution,
|
def parse_distribution(user_distribution,
|
||||||
clabels: Iterable[int] = None) -> dict[int, int]:
|
clabels: Iterable[int] = None) -> Dict[int, int]:
|
||||||
"""Parse user-provided distribution.
|
"""Parse user-provided distribution.
|
||||||
|
|
||||||
Return a dictionary with integer keys that represent the class labels and
|
Return a dictionary with integer keys that represent the class labels and
|
||||||
|
25
setup.py
25
setup.py
@@ -20,7 +20,7 @@ with open("README.md", "r") as fh:
|
|||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"torch>=1.3.1",
|
"torch>=1.3.1",
|
||||||
"torchvision>=0.6.0",
|
"torchvision>=0.7.1",
|
||||||
"numpy>=1.9.1",
|
"numpy>=1.9.1",
|
||||||
"sklearn",
|
"sklearn",
|
||||||
]
|
]
|
||||||
@@ -43,12 +43,15 @@ EXAMPLES = [
|
|||||||
"matplotlib",
|
"matplotlib",
|
||||||
"torchinfo",
|
"torchinfo",
|
||||||
]
|
]
|
||||||
TESTS = ["codecov", "pytest"]
|
TESTS = [
|
||||||
|
"pytest-cov",
|
||||||
|
"pytest",
|
||||||
|
]
|
||||||
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="prototorch",
|
name="prototorch",
|
||||||
version="0.6.0",
|
version="0.7.1",
|
||||||
description="Highly extensible, GPU-supported "
|
description="Highly extensible, GPU-supported "
|
||||||
"Learning Vector Quantization (LVQ) toolbox "
|
"Learning Vector Quantization (LVQ) toolbox "
|
||||||
"built using PyTorch and its nn API.",
|
"built using PyTorch and its nn API.",
|
||||||
@@ -59,7 +62,7 @@ setup(
|
|||||||
url=PROJECT_URL,
|
url=PROJECT_URL,
|
||||||
download_url=DOWNLOAD_URL,
|
download_url=DOWNLOAD_URL,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
python_requires=">=3.9",
|
python_requires=">=3.6",
|
||||||
install_requires=INSTALL_REQUIRES,
|
install_requires=INSTALL_REQUIRES,
|
||||||
extras_require={
|
extras_require={
|
||||||
"datasets": DATASETS,
|
"datasets": DATASETS,
|
||||||
@@ -70,18 +73,22 @@ setup(
|
|||||||
"all": ALL,
|
"all": ALL,
|
||||||
},
|
},
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Development Status :: 2 - Pre-Alpha",
|
|
||||||
"Environment :: Console",
|
"Environment :: Console",
|
||||||
|
"Natural Language :: English",
|
||||||
|
"Development Status :: 4 - Beta",
|
||||||
"Intended Audience :: Developers",
|
"Intended Audience :: Developers",
|
||||||
"Intended Audience :: Education",
|
"Intended Audience :: Education",
|
||||||
"Intended Audience :: Science/Research",
|
"Intended Audience :: Science/Research",
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Natural Language :: English",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Topic :: Software Development :: Libraries",
|
"Topic :: Software Development :: Libraries",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Operating System :: OS Independent",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.6",
|
||||||
|
"Programming Language :: Python :: 3.7",
|
||||||
|
"Programming Language :: Python :: 3.8",
|
||||||
|
"Programming Language :: Python :: 3.9",
|
||||||
],
|
],
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
zip_safe=False,
|
zip_safe=False,
|
||||||
|
@@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
|
|||||||
|
|
||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
def test_linear_transform():
|
def test_linear_transform_default_eye_init():
|
||||||
l = pt.transforms.LinearTransform(2, 4)
|
l = pt.transforms.LinearTransform(2, 4)
|
||||||
actual = l.weights
|
actual = l.weights
|
||||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_transform_forward():
|
||||||
|
l = pt.transforms.LinearTransform(4, 2)
|
||||||
|
actual_weights = l.weights
|
||||||
|
desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
|
||||||
|
assert torch.allclose(actual_weights, desired_weights)
|
||||||
|
actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[5.5, 6.6, 7.7, 8.8]]))
|
||||||
|
desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
|
||||||
|
assert torch.allclose(actual_outputs, desired_outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transform_zeros_init():
|
def test_linear_transform_zeros_init():
|
||||||
l = pt.transforms.LinearTransform(
|
l = pt.transforms.LinearTransform(
|
||||||
in_dim=2,
|
in_dim=2,
|
||||||
|
@@ -49,6 +49,23 @@ class TestNumpyDataset(unittest.TestCase):
|
|||||||
self.assertEqual(len(ds), 3)
|
self.assertEqual(len(ds), 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCSVDataset(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
data = np.random.rand(100, 4)
|
||||||
|
targets = np.random.randint(2, size=(100, 1))
|
||||||
|
arr = np.hstack([data, targets])
|
||||||
|
if not os.path.exists("./artifacts"):
|
||||||
|
os.mkdir("./artifacts")
|
||||||
|
np.savetxt("./artifacts/test.csv", arr, delimiter=",")
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
ds = pt.datasets.CSVDataset("./artifacts/test.csv")
|
||||||
|
self.assertEqual(len(ds), 100)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
os.remove("./artifacts/test.csv")
|
||||||
|
|
||||||
|
|
||||||
class TestSpiral(unittest.TestCase):
|
class TestSpiral(unittest.TestCase):
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
ds = pt.datasets.Spiral(num_samples=10)
|
ds = pt.datasets.Spiral(num_samples=10)
|
||||||
|
Reference in New Issue
Block a user