Compare commits
27 Commits
v0.6.0
...
feature/je
Author | SHA1 | Date | |
---|---|---|---|
|
0788718c31 | ||
|
4f5c4ebe8f | ||
|
ae2a8e54ef | ||
|
d9be100c1f | ||
|
9d1dc7320f | ||
|
d11ab71b7e | ||
|
59037e1a50 | ||
|
a19b99be82 | ||
|
f7e7558338 | ||
|
d57648f9d6 | ||
|
d24f580bf0 | ||
|
916973c3e8 | ||
|
b49b7a2d41 | ||
|
b6e8242383 | ||
|
cd616d11b9 | ||
|
afcfcb8973 | ||
|
bf03a45475 | ||
|
083b5c1597 | ||
|
7f0a8e9bce | ||
|
bf09ff8f7f | ||
|
c1d7cfee8f | ||
|
99be965581 | ||
|
fdb9a7c66d | ||
|
eb79b703d8 | ||
|
bc9a826b7d | ||
|
cfe09ec06b | ||
|
3d76dffe3c |
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.6.0
|
||||
current_version = 0.7.1
|
||||
commit = True
|
||||
tag = True
|
||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
|
||||
|
7
.ci/python310.Dockerfile
Normal file
7
.ci/python310.Dockerfile
Normal file
@@ -0,0 +1,7 @@
|
||||
FROM python:3.9
|
||||
|
||||
RUN adduser --uid 1000 jenkins
|
||||
|
||||
USER jenkins
|
||||
|
||||
RUN mkdir -p /home/jenkins/.cache/pip
|
7
.ci/python36.Dockerfile
Normal file
7
.ci/python36.Dockerfile
Normal file
@@ -0,0 +1,7 @@
|
||||
FROM python:3.6
|
||||
|
||||
RUN adduser --uid 1000 jenkins
|
||||
|
||||
USER jenkins
|
||||
|
||||
RUN mkdir -p /home/jenkins/.cache/pip
|
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
19
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -10,21 +10,28 @@ assignees: ''
|
||||
**Describe the bug**
|
||||
A clear and concise description of what the bug is.
|
||||
|
||||
**To Reproduce**
|
||||
Steps to reproduce the behavior:
|
||||
1. Install Prototorch by running '...'
|
||||
2. Run script '...'
|
||||
**Steps to reproduce the behavior**
|
||||
1. ...
|
||||
2. Run script '...' or this snippet:
|
||||
```python
|
||||
import prototorch as pt
|
||||
|
||||
...
|
||||
```
|
||||
3. See errors
|
||||
|
||||
**Expected behavior**
|
||||
A clear and concise description of what you expected to happen.
|
||||
|
||||
**Observed behavior**
|
||||
A clear and concise description of what actually happened.
|
||||
|
||||
**Screenshots**
|
||||
If applicable, add screenshots to help explain your problem.
|
||||
|
||||
**Desktop (please complete the following information):**
|
||||
**System and version information**
|
||||
- OS: [e.g. Ubuntu 20.10]
|
||||
- Prototorch Version: [e.g. v0.4.0]
|
||||
- ProtoTorch Version: [e.g. 0.4.0]
|
||||
- Python Version: [e.g. 3.9.5]
|
||||
|
||||
**Additional context**
|
||||
|
36
.travis.yml
36
.travis.yml
@@ -1,36 +0,0 @@
|
||||
dist: bionic
|
||||
sudo: false
|
||||
language: python
|
||||
python: 3.9
|
||||
cache:
|
||||
directories:
|
||||
- "$HOME/.cache/pip"
|
||||
- "./tests/artifacts"
|
||||
- "$HOME/datasets"
|
||||
install:
|
||||
- pip install .[all] --progress-bar off
|
||||
|
||||
# Generate code coverage report
|
||||
script:
|
||||
- coverage run -m pytest
|
||||
|
||||
# Push the results to codecov
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
# Publish on PyPI
|
||||
deploy:
|
||||
provider: pypi
|
||||
username: __token__
|
||||
password:
|
||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
||||
on:
|
||||
tags: true
|
||||
skip_existing: true
|
||||
|
||||
# The password is encrypted with:
|
||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||
# https://github.com/travis-ci/travis.rb#installation
|
||||
# for more details
|
||||
# Note: The encrypt command does not work well in ZSH.
|
41
Jenkinsfile
vendored
Normal file
41
Jenkinsfile
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
pipeline {
|
||||
agent none
|
||||
stages {
|
||||
stage('Unit Tests') {
|
||||
parallel {
|
||||
stage('3.6'){
|
||||
agent{
|
||||
dockerfile {
|
||||
filename 'python36.Dockerfile'
|
||||
dir '.ci'
|
||||
args '-v pip-cache:/home/jenkins/.cache/pip'
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pip install pip --upgrade --progress-bar off'
|
||||
sh 'pip install .[all] --progress-bar off'
|
||||
sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
|
||||
cobertura coberturaReportFile: 'reports/coverage.xml'
|
||||
junit 'reports/**/*.xml'
|
||||
}
|
||||
}
|
||||
stage('3.10'){
|
||||
agent{
|
||||
dockerfile {
|
||||
filename 'python310.Dockerfile'
|
||||
dir '.ci'
|
||||
args '-v pip-cache:/home/jenkins/.cache/pip'
|
||||
}
|
||||
}
|
||||
steps {
|
||||
sh 'pip install pip --upgrade --progress-bar off'
|
||||
sh 'pip install .[all] --progress-bar off'
|
||||
sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
|
||||
cobertura coberturaReportFile: 'reports/coverage.xml'
|
||||
junit 'reports/**/*.xml'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
46
deprecated.travis.yml
Normal file
46
deprecated.travis.yml
Normal file
@@ -0,0 +1,46 @@
|
||||
dist: bionic
|
||||
sudo: false
|
||||
language: python
|
||||
python:
|
||||
- 3.9
|
||||
- 3.8
|
||||
- 3.7
|
||||
- 3.6
|
||||
cache:
|
||||
directories:
|
||||
- "$HOME/.cache/pip"
|
||||
- "./tests/artifacts"
|
||||
- "$HOME/datasets"
|
||||
install:
|
||||
- pip install .[all] --progress-bar off
|
||||
|
||||
# Generate code coverage report
|
||||
script:
|
||||
- coverage run -m pytest
|
||||
|
||||
# Push the results to codecov
|
||||
after_success:
|
||||
- bash <(curl -s https://codecov.io/bash)
|
||||
|
||||
# Publish on PyPI
|
||||
jobs:
|
||||
include:
|
||||
- stage: build
|
||||
python: 3.9
|
||||
script: echo "Starting Pypi build"
|
||||
deploy:
|
||||
provider: pypi
|
||||
username: __token__
|
||||
distributions: "sdist bdist_wheel"
|
||||
password:
|
||||
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
|
||||
on:
|
||||
tags: true
|
||||
skip_existing: true
|
||||
|
||||
# The password is encrypted with:
|
||||
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
|
||||
# See https://docs.travis-ci.com/user/deployment/pypi and
|
||||
# https://github.com/travis-ci/travis.rb#installation
|
||||
# for more details
|
||||
# Note: The encrypt command does not work well in ZSH.
|
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
|
||||
|
||||
# The full version, including alpha/beta/rc tags
|
||||
#
|
||||
release = "0.6.0"
|
||||
release = "0.7.1"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
|
@@ -22,7 +22,7 @@ from .core import (
|
||||
)
|
||||
|
||||
# Core Setup
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.7.1"
|
||||
|
||||
__all_core__ = [
|
||||
"competitions",
|
||||
|
@@ -48,7 +48,7 @@ class WTAC(torch.nn.Module):
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, distances, labels):
|
||||
def forward(self, distances, labels): # pylint: disable=no-self-use
|
||||
return wtac(distances, labels)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ class LTAC(torch.nn.Module):
|
||||
Thin wrapper over the `wtac` function.
|
||||
|
||||
"""
|
||||
def forward(self, probs, labels):
|
||||
def forward(self, probs, labels): # pylint: disable=no-self-use
|
||||
return wtac(-1.0 * probs, labels)
|
||||
|
||||
|
||||
@@ -85,5 +85,5 @@ class CBCC(torch.nn.Module):
|
||||
Thin wrapper over the `cbcc` function.
|
||||
|
||||
"""
|
||||
def forward(self, detections, reasonings):
|
||||
def forward(self, detections, reasonings): # pylint: disable=no-self-use
|
||||
return cbcc(detections, reasonings)
|
||||
|
@@ -253,8 +253,10 @@ class Reasonings(torch.nn.Module):
|
||||
self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
initializer:
|
||||
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
|
||||
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
|
||||
):
|
||||
super().__init__()
|
||||
self.add_reasonings(distribution, initializer)
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
|
@@ -41,9 +41,6 @@ def euclidean_distance_v2(x, y):
|
||||
# batch diagonal. See:
|
||||
# https://pytorch.org/docs/stable/generated/torch.diagonal.html
|
||||
distances = torch.diagonal(pairwise_distances, dim1=-2, dim2=-1)
|
||||
# print(f"{diff.shape=}") # (nx, ny, ndim)
|
||||
# print(f"{pairwise_distances.shape=}") # (nx, ny, ny)
|
||||
# print(f"{distances.shape=}") # (nx, ny)
|
||||
return distances
|
||||
|
||||
|
||||
|
@@ -303,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
||||
|
||||
|
||||
# Reasonings
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
def __init__(self, components_first: bool = True):
|
||||
self.components_first = components_first
|
||||
|
||||
def compute_shape(self, distribution):
|
||||
def compute_distribution_shape(distribution):
|
||||
distribution = parse_distribution(distribution)
|
||||
num_components = sum(distribution.values())
|
||||
num_classes = len(distribution.keys())
|
||||
return (num_components, num_classes, 2)
|
||||
|
||||
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
def __init__(self, components_first: bool = True):
|
||||
self.components_first = components_first
|
||||
|
||||
def generate_end_hook(self, reasonings):
|
||||
if not self.components_first:
|
||||
reasonings = reasonings.permute(2, 1, 0)
|
||||
@@ -349,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with zeros."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.zeros(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
@@ -358,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Reasonings are all initialized with ones."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.ones(*shape)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
@@ -372,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
self.maximum = maximum
|
||||
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
shape = self.compute_shape(distribution)
|
||||
shape = compute_distribution_shape(distribution)
|
||||
reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
@@ -381,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
"""Each component reasons positively for exactly one class."""
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
||||
num_components, num_classes, _ = compute_distribution_shape(
|
||||
distribution)
|
||||
A = OneHotLabelsInitializer().generate(distribution)
|
||||
B = torch.zeros(num_components, num_classes)
|
||||
reasonings = torch.stack([A, B], dim=-1)
|
||||
|
@@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3):
|
||||
|
||||
|
||||
class GLVQLoss(torch.nn.Module):
|
||||
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
||||
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.margin = margin
|
||||
self.squashing = get_activation(squashing)
|
||||
self.transfer_fn = get_activation(transfer_fn)
|
||||
self.beta = torch.tensor(beta)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
def forward(self, outputs, targets, plabels):
|
||||
mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
||||
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
|
||||
return batch_loss.sum()
|
||||
|
||||
|
||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||
|
@@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
|
||||
|
||||
class StratifiedSumPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_sum_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_sum_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedProdPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_prod_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_prod_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMinPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_min_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_min_pooling(values, labels)
|
||||
|
||||
|
||||
class StratifiedMaxPooling(torch.nn.Module):
|
||||
"""Thin wrapper over the `stratified_max_pooling` function."""
|
||||
def forward(self, values, labels):
|
||||
def forward(self, values, labels): # pylint: disable=no-self-use
|
||||
return stratified_max_pooling(values, labels)
|
||||
|
@@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
|
||||
self._register_weights(weights)
|
||||
|
||||
def forward(self, x):
|
||||
return x @ self.weights.T
|
||||
return x @ self.weights
|
||||
|
||||
|
||||
# Aliases
|
||||
|
@@ -1,6 +1,6 @@
|
||||
"""ProtoTorch datasets"""
|
||||
|
||||
from .abstract import NumpyDataset
|
||||
from .abstract import CSVDataset, NumpyDataset
|
||||
from .sklearn import (
|
||||
Blobs,
|
||||
Circles,
|
||||
@@ -10,3 +10,4 @@ from .sklearn import (
|
||||
)
|
||||
from .spiral import Spiral
|
||||
from .tecator import Tecator
|
||||
from .xor import XOR
|
||||
|
@@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
|
||||
self.targets = torch.LongTensor(targets)
|
||||
tensors = [self.data, self.targets]
|
||||
super().__init__(*tensors)
|
||||
|
||||
|
||||
class CSVDataset(NumpyDataset):
|
||||
"""Create a Dataset from a CSV file."""
|
||||
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
|
||||
raw = np.genfromtxt(
|
||||
filepath,
|
||||
delimiter=delimiter,
|
||||
skip_header=skip_header,
|
||||
)
|
||||
data = np.delete(raw, 1, target_col)
|
||||
targets = raw[:, target_col]
|
||||
super().__init__(data, targets)
|
||||
|
18
prototorch/datasets/xor.py
Normal file
18
prototorch/datasets/xor.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def make_xor(num_samples=500):
|
||||
x = torch.rand(num_samples, 2)
|
||||
y = torch.zeros(num_samples)
|
||||
y[torch.logical_and(x[:, 0] > 0.5, x[:, 1] < 0.5)] = 1
|
||||
y[torch.logical_and(x[:, 1] > 0.5, x[:, 0] < 0.5)] = 1
|
||||
return x, y
|
||||
|
||||
|
||||
class XOR(torch.utils.data.TensorDataset):
|
||||
"""Exclusive-or (XOR) dataset for binary classification."""
|
||||
def __init__(self, num_samples: int = 500):
|
||||
x, y = make_xor(num_samples)
|
||||
super().__init__(x, y)
|
@@ -1,8 +1,12 @@
|
||||
"""ProtoFlow utilities"""
|
||||
|
||||
import warnings
|
||||
from collections.abc import Iterable
|
||||
from typing import Union
|
||||
from typing import (
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -24,7 +28,7 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
|
||||
return mesh, xx, yy
|
||||
|
||||
|
||||
def distribution_from_list(list_dist: list[int],
|
||||
def distribution_from_list(list_dist: List[int],
|
||||
clabels: Iterable[int] = None):
|
||||
clabels = clabels or list(range(len(list_dist)))
|
||||
distribution = dict(zip(clabels, list_dist))
|
||||
@@ -32,7 +36,7 @@ def distribution_from_list(list_dist: list[int],
|
||||
|
||||
|
||||
def parse_distribution(user_distribution,
|
||||
clabels: Iterable[int] = None) -> dict[int, int]:
|
||||
clabels: Iterable[int] = None) -> Dict[int, int]:
|
||||
"""Parse user-provided distribution.
|
||||
|
||||
Return a dictionary with integer keys that represent the class labels and
|
||||
|
25
setup.py
25
setup.py
@@ -20,7 +20,7 @@ with open("README.md", "r") as fh:
|
||||
|
||||
INSTALL_REQUIRES = [
|
||||
"torch>=1.3.1",
|
||||
"torchvision>=0.6.0",
|
||||
"torchvision>=0.7.1",
|
||||
"numpy>=1.9.1",
|
||||
"sklearn",
|
||||
]
|
||||
@@ -43,12 +43,15 @@ EXAMPLES = [
|
||||
"matplotlib",
|
||||
"torchinfo",
|
||||
]
|
||||
TESTS = ["codecov", "pytest"]
|
||||
TESTS = [
|
||||
"pytest-cov",
|
||||
"pytest",
|
||||
]
|
||||
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
|
||||
|
||||
setup(
|
||||
name="prototorch",
|
||||
version="0.6.0",
|
||||
version="0.7.1",
|
||||
description="Highly extensible, GPU-supported "
|
||||
"Learning Vector Quantization (LVQ) toolbox "
|
||||
"built using PyTorch and its nn API.",
|
||||
@@ -59,7 +62,7 @@ setup(
|
||||
url=PROJECT_URL,
|
||||
download_url=DOWNLOAD_URL,
|
||||
license="MIT",
|
||||
python_requires=">=3.9",
|
||||
python_requires=">=3.6",
|
||||
install_requires=INSTALL_REQUIRES,
|
||||
extras_require={
|
||||
"datasets": DATASETS,
|
||||
@@ -70,18 +73,22 @@ setup(
|
||||
"all": ALL,
|
||||
},
|
||||
classifiers=[
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Environment :: Console",
|
||||
"Natural Language :: English",
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Education",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Natural Language :: English",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Operating System :: OS Independent",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
],
|
||||
packages=find_packages(),
|
||||
zip_safe=False,
|
||||
|
@@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
|
||||
|
||||
|
||||
# Transforms
|
||||
def test_linear_transform():
|
||||
def test_linear_transform_default_eye_init():
|
||||
l = pt.transforms.LinearTransform(2, 4)
|
||||
actual = l.weights
|
||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_forward():
|
||||
l = pt.transforms.LinearTransform(4, 2)
|
||||
actual_weights = l.weights
|
||||
desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
|
||||
assert torch.allclose(actual_weights, desired_weights)
|
||||
actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
|
||||
[1.1, 2.2, 3.3, 4.4], \
|
||||
[5.5, 6.6, 7.7, 8.8]]))
|
||||
desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
|
||||
assert torch.allclose(actual_outputs, desired_outputs)
|
||||
|
||||
|
||||
def test_linear_transform_zeros_init():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
|
@@ -49,6 +49,23 @@ class TestNumpyDataset(unittest.TestCase):
|
||||
self.assertEqual(len(ds), 3)
|
||||
|
||||
|
||||
class TestCSVDataset(unittest.TestCase):
|
||||
def setUp(self):
|
||||
data = np.random.rand(100, 4)
|
||||
targets = np.random.randint(2, size=(100, 1))
|
||||
arr = np.hstack([data, targets])
|
||||
if not os.path.exists("./artifacts"):
|
||||
os.mkdir("./artifacts")
|
||||
np.savetxt("./artifacts/test.csv", arr, delimiter=",")
|
||||
|
||||
def test_len(self):
|
||||
ds = pt.datasets.CSVDataset("./artifacts/test.csv")
|
||||
self.assertEqual(len(ds), 100)
|
||||
|
||||
def tearDown(self):
|
||||
os.remove("./artifacts/test.csv")
|
||||
|
||||
|
||||
class TestSpiral(unittest.TestCase):
|
||||
def test_init(self):
|
||||
ds = pt.datasets.Spiral(num_samples=10)
|
||||
|
Reference in New Issue
Block a user