Use github actions for CI (#10)

* chore: Absolute imports

* feat: Add new mesh util

* chore: replace bumpversion

original fork no longer maintained, move config

* ci: remove old configuration files

* ci: update github action

* ci: add python 3.10 test

* chore: update pre-commit hooks

* ci: update supported python versions

supported are 3.7, 3.8 and 3.9.

3.6 had EOL in december 2021.
3.10 has no pytorch distribution yet.

* ci: add windows test

* ci: update action

less windows tests, pre commit

* ci: fix typo

* chore: run precommit for all files

* ci: two step tests

* ci: compatibility waits for style

* fix: init file had missing imports

* ci: add deployment script

* ci: skip complete publish step

* ci: cleanup readme
This commit is contained in:
Alexander Engelsberger 2022-01-10 20:23:18 +01:00 committed by GitHub
parent b49b7a2d41
commit a28601751e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 205 additions and 123 deletions

View File

@ -1,13 +0,0 @@
[bumpversion]
current_version = 0.7.1
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version}
[bumpversion:file:setup.py]
[bumpversion:file:./prototorch/__init__.py]
[bumpversion:file:./docs/source/conf.py]

View File

@ -1,15 +0,0 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
pylintpython3:
exclude_paths:
- config/engines.yml
remark-lint:
exclude_paths:
- config/engines.yml
exclude_paths:
- 'tests/**'

View File

@ -1,2 +0,0 @@
comment:
require_changes: yes

View File

@ -5,33 +5,69 @@ name: tests
on: on:
push: push:
branches: [ master, dev ]
pull_request: pull_request:
branches: [ master ] branches: [ master ]
jobs: jobs:
build: style:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python 3.9 - name: Set up Python 3.9
uses: actions/setup-python@v1 uses: actions/setup-python@v2
with: with:
python-version: 3.9 python-version: 3.9
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install .[all] pip install .[all]
- name: Lint with flake8 - uses: pre-commit/action@v2.0.3
compatibility:
needs: style
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.7"
- os: windows-latest
python-version: "3.8"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: | run: |
pip install flake8 python -m pip install --upgrade pip
# stop the build if there are Python syntax errors or undefined names pip install .[all]
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest - name: Test with pytest
run: | run: |
pip install pytest
pytest pytest
publish_pypi:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs: compatibility
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: "3.9"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
pip install wheel
- name: Build package
run: python setup.py sdist bdist_wheel
- name: Publish a Python distribution to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

View File

@ -3,7 +3,7 @@
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1 rev: v4.1.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: end-of-file-fixer - id: end-of-file-fixer
@ -18,19 +18,19 @@ repos:
- id: autoflake - id: autoflake
- repo: http://github.com/PyCQA/isort - repo: http://github.com/PyCQA/isort
rev: 5.8.0 rev: 5.10.1
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.902 rev: v0.931
hooks: hooks:
- id: mypy - id: mypy
files: prototorch files: prototorch
additional_dependencies: [types-pkg_resources] additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.31.0 rev: v0.32.0
hooks: hooks:
- id: yapf - id: yapf
@ -42,7 +42,7 @@ repos:
- id: python-check-blanket-noqa - id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v2.19.4 rev: v2.31.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade

View File

@ -1,46 +0,0 @@
dist: bionic
sudo: false
language: python
python:
- 3.9
- 3.8
- 3.7
- 3.6
cache:
directories:
- "$HOME/.cache/pip"
- "./tests/artifacts"
- "$HOME/datasets"
install:
- pip install .[all] --progress-bar off
# Generate code coverage report
script:
- coverage run -m pytest
# Push the results to codecov
after_success:
- bash <(curl -s https://codecov.io/bash)
# Publish on PyPI
jobs:
include:
- stage: build
python: 3.9
script: echo "Starting Pypi build"
deploy:
provider: pypi
username: __token__
distributions: "sdist bdist_wheel"
password:
secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
on:
tags: true
skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

View File

@ -2,12 +2,9 @@
![ProtoTorch Logo](https://prototorch.readthedocs.io/en/latest/_static/horizontal-lockup.png) ![ProtoTorch Logo](https://prototorch.readthedocs.io/en/latest/_static/horizontal-lockup.png)
[![Build Status](https://api.travis-ci.com/si-cim/prototorch.svg?branch=master)](https://travis-ci.com/github/si-cim/prototorch)
![tests](https://github.com/si-cim/prototorch/workflows/tests/badge.svg) ![tests](https://github.com/si-cim/prototorch/workflows/tests/badge.svg)
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch?color=yellow&label=version)](https://github.com/si-cim/prototorch/releases) [![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch?color=yellow&label=version)](https://github.com/si-cim/prototorch/releases)
[![PyPI](https://img.shields.io/pypi/v/prototorch)](https://pypi.org/project/prototorch/) [![PyPI](https://img.shields.io/pypi/v/prototorch)](https://pypi.org/project/prototorch/)
[![codecov](https://codecov.io/gh/si-cim/prototorch/branch/master/graph/badge.svg)](https://codecov.io/gh/si-cim/prototorch)
[![Codacy Badge](https://api.codacy.com/project/badge/Grade/76273904bf9343f0a8b29cd8aca242e7)](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=si-cim/prototorch&amp;utm_campaign=Badge_Grade)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE) [![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE)
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow) *Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)

View File

@ -7,6 +7,7 @@ import prototorch as pt
class CBC(torch.nn.Module): class CBC(torch.nn.Module):
def __init__(self, data, **kwargs): def __init__(self, data, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.components_layer = pt.components.ReasoningComponents( self.components_layer = pt.components.ReasoningComponents(
@ -23,6 +24,7 @@ class CBC(torch.nn.Module):
class VisCBC2D(): class VisCBC2D():
def __init__(self, model, data): def __init__(self, model, data):
self.model = model self.model = model
self.x_train, self.y_train = pt.utils.parse_data_arg(data) self.x_train, self.y_train = pt.utils.parse_data_arg(data)

View File

@ -1,25 +1,20 @@
"""ProtoTorch package""" """ProtoTorch package"""
import pkgutil import pkgutil
from typing import List
import pkg_resources import pkg_resources
from . import ( from . import datasets # noqa: F401
datasets, from . import nn # noqa: F401
nn, from . import utils # noqa: F401
utils, from .core import competitions # noqa: F401
) from .core import components # noqa: F401
from .core import ( from .core import distances # noqa: F401
competitions, from .core import initializers # noqa: F401
components, from .core import losses # noqa: F401
distances, from .core import pooling # noqa: F401
initializers, from .core import similarities # noqa: F401
losses, from .core import transforms # noqa: F401
pooling,
similarities,
transforms,
)
# Core Setup # Core Setup
__version__ = "0.7.1" __version__ = "0.7.1"
@ -40,7 +35,7 @@ __all_core__ = [
] ]
# Plugin Loader # Plugin Loader
__path__: List[str] = pkgutil.extend_path(__path__, __name__) __path__ = pkgutil.extend_path(__path__, __name__)
def discover_plugins(): def discover_plugins():

View File

@ -48,6 +48,7 @@ class WTAC(torch.nn.Module):
Thin wrapper over the `wtac` function. Thin wrapper over the `wtac` function.
""" """
def forward(self, distances, labels): # pylint: disable=no-self-use def forward(self, distances, labels): # pylint: disable=no-self-use
return wtac(distances, labels) return wtac(distances, labels)
@ -58,6 +59,7 @@ class LTAC(torch.nn.Module):
Thin wrapper over the `wtac` function. Thin wrapper over the `wtac` function.
""" """
def forward(self, probs, labels): # pylint: disable=no-self-use def forward(self, probs, labels): # pylint: disable=no-self-use
return wtac(-1.0 * probs, labels) return wtac(-1.0 * probs, labels)
@ -68,6 +70,7 @@ class KNNC(torch.nn.Module):
Thin wrapper over the `knnc` function. Thin wrapper over the `knnc` function.
""" """
def __init__(self, k=1, **kwargs): def __init__(self, k=1, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.k = k self.k = k
@ -85,5 +88,6 @@ class CBCC(torch.nn.Module):
Thin wrapper over the `cbcc` function. Thin wrapper over the `cbcc` function.
""" """
def forward(self, detections, reasonings): # pylint: disable=no-self-use def forward(self, detections, reasonings): # pylint: disable=no-self-use
return cbcc(detections, reasonings) return cbcc(detections, reasonings)

View File

@ -6,7 +6,8 @@ from typing import Union
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from ..utils import parse_distribution from prototorch.utils import parse_distribution
from .initializers import ( from .initializers import (
AbstractClassAwareCompInitializer, AbstractClassAwareCompInitializer,
AbstractComponentsInitializer, AbstractComponentsInitializer,
@ -63,6 +64,7 @@ def get_cikwargs(init, distribution):
class AbstractComponents(torch.nn.Module): class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules.""" """Abstract class for all components modules."""
@property @property
def num_components(self): def num_components(self):
"""Current number of components.""" """Current number of components."""
@ -85,6 +87,7 @@ class AbstractComponents(torch.nn.Module):
class Components(AbstractComponents): class Components(AbstractComponents):
"""A set of adaptable Tensors.""" """A set of adaptable Tensors."""
def __init__(self, num_components: int, def __init__(self, num_components: int,
initializer: AbstractComponentsInitializer): initializer: AbstractComponentsInitializer):
super().__init__() super().__init__()
@ -112,6 +115,7 @@ class Components(AbstractComponents):
class AbstractLabels(torch.nn.Module): class AbstractLabels(torch.nn.Module):
"""Abstract class for all labels modules.""" """Abstract class for all labels modules."""
@property @property
def labels(self): def labels(self):
return self._labels.cpu() return self._labels.cpu()
@ -152,6 +156,7 @@ class AbstractLabels(torch.nn.Module):
class Labels(AbstractLabels): class Labels(AbstractLabels):
"""A set of standalone labels.""" """A set of standalone labels."""
def __init__(self, def __init__(self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],
initializer: AbstractLabelsInitializer = LabelsInitializer()): initializer: AbstractLabelsInitializer = LabelsInitializer()):
@ -182,6 +187,7 @@ class Labels(AbstractLabels):
class LabeledComponents(AbstractComponents): class LabeledComponents(AbstractComponents):
"""A set of adaptable components and corresponding unadaptable labels.""" """A set of adaptable components and corresponding unadaptable labels."""
def __init__( def __init__(
self, self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],
@ -249,6 +255,7 @@ class Reasonings(torch.nn.Module):
The `reasonings` tensor is of shape [num_components, num_classes, 2]. The `reasonings` tensor is of shape [num_components, num_classes, 2].
""" """
def __init__( def __init__(
self, self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],
@ -308,6 +315,7 @@ class ReasoningComponents(AbstractComponents):
three element probability distribution. three element probability distribution.
""" """
def __init__( def __init__(
self, self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],

View File

@ -11,7 +11,7 @@ from typing import (
import torch import torch
from ..utils import parse_data_arg, parse_distribution from prototorch.utils import parse_data_arg, parse_distribution
# Components # Components
@ -26,6 +26,7 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
Use this to 'generate' pre-initialized components elsewhere. Use this to 'generate' pre-initialized components elsewhere.
""" """
def __init__(self, components): def __init__(self, components):
self.components = components self.components = components
@ -40,6 +41,7 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
class ShapeAwareCompInitializer(AbstractComponentsInitializer): class ShapeAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all dimension-aware components initializers.""" """Abstract class for all dimension-aware components initializers."""
def __init__(self, shape: Union[Iterable, int]): def __init__(self, shape: Union[Iterable, int]):
if isinstance(shape, Iterable): if isinstance(shape, Iterable):
self.component_shape = tuple(shape) self.component_shape = tuple(shape)
@ -53,6 +55,7 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
class ZerosCompInitializer(ShapeAwareCompInitializer): class ZerosCompInitializer(ShapeAwareCompInitializer):
"""Generate zeros corresponding to the components shape.""" """Generate zeros corresponding to the components shape."""
def generate(self, num_components: int): def generate(self, num_components: int):
components = torch.zeros((num_components, ) + self.component_shape) components = torch.zeros((num_components, ) + self.component_shape)
return components return components
@ -60,6 +63,7 @@ class ZerosCompInitializer(ShapeAwareCompInitializer):
class OnesCompInitializer(ShapeAwareCompInitializer): class OnesCompInitializer(ShapeAwareCompInitializer):
"""Generate ones corresponding to the components shape.""" """Generate ones corresponding to the components shape."""
def generate(self, num_components: int): def generate(self, num_components: int):
components = torch.ones((num_components, ) + self.component_shape) components = torch.ones((num_components, ) + self.component_shape)
return components return components
@ -67,6 +71,7 @@ class OnesCompInitializer(ShapeAwareCompInitializer):
class FillValueCompInitializer(OnesCompInitializer): class FillValueCompInitializer(OnesCompInitializer):
"""Generate components with the provided `fill_value`.""" """Generate components with the provided `fill_value`."""
def __init__(self, shape, fill_value: float = 1.0): def __init__(self, shape, fill_value: float = 1.0):
super().__init__(shape) super().__init__(shape)
self.fill_value = fill_value self.fill_value = fill_value
@ -79,6 +84,7 @@ class FillValueCompInitializer(OnesCompInitializer):
class UniformCompInitializer(OnesCompInitializer): class UniformCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a continuous uniform distribution.""" """Generate components by sampling from a continuous uniform distribution."""
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0): def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
super().__init__(shape) super().__init__(shape)
self.minimum = minimum self.minimum = minimum
@ -93,6 +99,7 @@ class UniformCompInitializer(OnesCompInitializer):
class RandomNormalCompInitializer(OnesCompInitializer): class RandomNormalCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a standard normal distribution.""" """Generate components by sampling from a standard normal distribution."""
def __init__(self, shape, shift=0.0, scale=1.0): def __init__(self, shape, shift=0.0, scale=1.0):
super().__init__(shape) super().__init__(shape)
self.shift = shift self.shift = shift
@ -113,6 +120,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
`data` has to be a torch tensor. `data` has to be a torch tensor.
""" """
def __init__(self, def __init__(self,
data: torch.Tensor, data: torch.Tensor,
noise: float = 0.0, noise: float = 0.0,
@ -137,6 +145,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
class DataAwareCompInitializer(AbstractDataAwareCompInitializer): class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
"""'Generate' the components from the provided data.""" """'Generate' the components from the provided data."""
def generate(self, num_components: int = 0): def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return transformed `self.data`.""" """Ignore `num_components` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data) components = self.generate_end_hook(self.data)
@ -145,6 +154,7 @@ class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
class SelectionCompInitializer(AbstractDataAwareCompInitializer): class SelectionCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by uniformly sampling from the provided data.""" """Generate components by uniformly sampling from the provided data."""
def generate(self, num_components: int): def generate(self, num_components: int):
indices = torch.LongTensor(num_components).random_(0, len(self.data)) indices = torch.LongTensor(num_components).random_(0, len(self.data))
samples = self.data[indices] samples = self.data[indices]
@ -154,6 +164,7 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
class MeanCompInitializer(AbstractDataAwareCompInitializer): class MeanCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by computing the mean of the provided data.""" """Generate components by computing the mean of the provided data."""
def generate(self, num_components: int): def generate(self, num_components: int):
mean = self.data.mean(dim=0) mean = self.data.mean(dim=0)
repeat_dim = [num_components] + [1] * len(mean.shape) repeat_dim = [num_components] + [1] * len(mean.shape)
@ -172,6 +183,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
target tensors. target tensors.
""" """
def __init__(self, def __init__(self,
data, data,
noise: float = 0.0, noise: float = 0.0,
@ -199,6 +211,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
"""'Generate' components from provided data and requested distribution.""" """'Generate' components from provided data and requested distribution."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
"""Ignore `distribution` and simply return transformed `self.data`.""" """Ignore `distribution` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data) components = self.generate_end_hook(self.data)
@ -207,6 +220,7 @@ class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
"""Abstract class for all stratified components initializers.""" """Abstract class for all stratified components initializers."""
@property @property
@abstractmethod @abstractmethod
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]: def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
@ -229,6 +243,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer): class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components using stratified sampling from the provided data.""" """Generate components using stratified sampling from the provided data."""
@property @property
def subinit_type(self): def subinit_type(self):
return SelectionCompInitializer return SelectionCompInitializer
@ -236,6 +251,7 @@ class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer): class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components at stratified means of the provided data.""" """Generate components at stratified means of the provided data."""
@property @property
def subinit_type(self): def subinit_type(self):
return MeanCompInitializer return MeanCompInitializer
@ -244,6 +260,7 @@ class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
# Labels # Labels
class AbstractLabelsInitializer(ABC): class AbstractLabelsInitializer(ABC):
"""Abstract class for all labels initializers.""" """Abstract class for all labels initializers."""
@abstractmethod @abstractmethod
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
... ...
@ -255,6 +272,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
Use this to 'generate' pre-initialized labels elsewhere. Use this to 'generate' pre-initialized labels elsewhere.
""" """
def __init__(self, labels): def __init__(self, labels):
self.labels = labels self.labels = labels
@ -273,6 +291,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
class DataAwareLabelsInitializer(AbstractLabelsInitializer): class DataAwareLabelsInitializer(AbstractLabelsInitializer):
"""'Generate' the labels from a torch Dataset.""" """'Generate' the labels from a torch Dataset."""
def __init__(self, data): def __init__(self, data):
self.data, self.targets = parse_data_arg(data) self.data, self.targets = parse_data_arg(data)
@ -283,6 +302,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
class LabelsInitializer(AbstractLabelsInitializer): class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels from `distribution`.""" """Generate labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
labels_list = [] labels_list = []
@ -294,6 +314,7 @@ class LabelsInitializer(AbstractLabelsInitializer):
class OneHotLabelsInitializer(LabelsInitializer): class OneHotLabelsInitializer(LabelsInitializer):
"""Generate one-hot-encoded labels from `distribution`.""" """Generate one-hot-encoded labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
num_classes = len(distribution.keys()) num_classes = len(distribution.keys())
@ -312,6 +333,7 @@ def compute_distribution_shape(distribution):
class AbstractReasoningsInitializer(ABC): class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers.""" """Abstract class for all reasonings initializers."""
def __init__(self, components_first: bool = True): def __init__(self, components_first: bool = True):
self.components_first = components_first self.components_first = components_first
@ -332,6 +354,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
Use this to 'generate' pre-initialized reasonings elsewhere. Use this to 'generate' pre-initialized reasonings elsewhere.
""" """
def __init__(self, reasonings, **kwargs): def __init__(self, reasonings, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.reasonings = reasonings self.reasonings = reasonings
@ -349,6 +372,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
class ZerosReasoningsInitializer(AbstractReasoningsInitializer): class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with zeros.""" """Reasonings are all initialized with zeros."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
shape = compute_distribution_shape(distribution) shape = compute_distribution_shape(distribution)
reasonings = torch.zeros(*shape) reasonings = torch.zeros(*shape)
@ -358,6 +382,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
class OnesReasoningsInitializer(AbstractReasoningsInitializer): class OnesReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with ones.""" """Reasonings are all initialized with ones."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
shape = compute_distribution_shape(distribution) shape = compute_distribution_shape(distribution)
reasonings = torch.ones(*shape) reasonings = torch.ones(*shape)
@ -367,6 +392,7 @@ class OnesReasoningsInitializer(AbstractReasoningsInitializer):
class RandomReasoningsInitializer(AbstractReasoningsInitializer): class RandomReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are randomly initialized.""" """Reasonings are randomly initialized."""
def __init__(self, minimum=0.4, maximum=0.6, **kwargs): def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.minimum = minimum self.minimum = minimum
@ -381,6 +407,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Each component reasons positively for exactly one class.""" """Each component reasons positively for exactly one class."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
num_components, num_classes, _ = compute_distribution_shape( num_components, num_classes, _ = compute_distribution_shape(
distribution) distribution)
@ -399,6 +426,7 @@ class AbstractTransformInitializer(ABC):
class AbstractLinearTransformInitializer(AbstractTransformInitializer): class AbstractLinearTransformInitializer(AbstractTransformInitializer):
"""Abstract class for all linear transform initializers.""" """Abstract class for all linear transform initializers."""
def __init__(self, out_dim_first: bool = False): def __init__(self, out_dim_first: bool = False):
self.out_dim_first = out_dim_first self.out_dim_first = out_dim_first
@ -415,6 +443,7 @@ class AbstractLinearTransformInitializer(AbstractTransformInitializer):
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer): class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with zeros.""" """Initialize a matrix with zeros."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim) weights = torch.zeros(in_dim, out_dim)
return self.generate_end_hook(weights) return self.generate_end_hook(weights)
@ -422,6 +451,7 @@ class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with ones.""" """Initialize a matrix with ones."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
weights = torch.ones(in_dim, out_dim) weights = torch.ones(in_dim, out_dim)
return self.generate_end_hook(weights) return self.generate_end_hook(weights)
@ -429,6 +459,7 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
class EyeTransformInitializer(AbstractLinearTransformInitializer): class EyeTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with the largest possible identity matrix.""" """Initialize a matrix with the largest possible identity matrix."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim) weights = torch.zeros(in_dim, out_dim)
I = torch.eye(min(in_dim, out_dim)) I = torch.eye(min(in_dim, out_dim))
@ -438,6 +469,7 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer):
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
"""Abstract class for all data-aware linear transform initializers.""" """Abstract class for all data-aware linear transform initializers."""
def __init__(self, def __init__(self,
data: torch.Tensor, data: torch.Tensor,
noise: float = 0.0, noise: float = 0.0,
@ -458,6 +490,7 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
"""Initialize a matrix with Eigenvectors from the data.""" """Initialize a matrix with Eigenvectors from the data."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
_, _, weights = torch.pca_lowrank(self.data, q=out_dim) _, _, weights = torch.pca_lowrank(self.data, q=out_dim)
return self.generate_end_hook(weights) return self.generate_end_hook(weights)

View File

@ -2,7 +2,7 @@
import torch import torch
from ..nn.activations import get_activation from prototorch.nn.activations import get_activation
# Helpers # Helpers
@ -106,6 +106,7 @@ def margin_loss(y_pred, y_true, margin=0.3):
class GLVQLoss(torch.nn.Module): class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs): def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.margin = margin self.margin = margin
@ -119,6 +120,7 @@ class GLVQLoss(torch.nn.Module):
class MarginLoss(torch.nn.modules.loss._Loss): class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self, def __init__(self,
margin=0.3, margin=0.3,
size_average=None, size_average=None,
@ -132,6 +134,7 @@ class MarginLoss(torch.nn.modules.loss._Loss):
class NeuralGasEnergy(torch.nn.Module): class NeuralGasEnergy(torch.nn.Module):
def __init__(self, lm, **kwargs): def __init__(self, lm, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.lm = lm self.lm = lm
@ -152,6 +155,7 @@ class NeuralGasEnergy(torch.nn.Module):
class GrowingNeuralGasEnergy(NeuralGasEnergy): class GrowingNeuralGasEnergy(NeuralGasEnergy):
def __init__(self, topology_layer, **kwargs): def __init__(self, topology_layer, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.topology_layer = topology_layer self.topology_layer = topology_layer

View File

@ -82,23 +82,27 @@ def stratified_prod_pooling(values: torch.Tensor,
class StratifiedSumPooling(torch.nn.Module): class StratifiedSumPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_sum_pooling` function.""" """Thin wrapper over the `stratified_sum_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_sum_pooling(values, labels) return stratified_sum_pooling(values, labels)
class StratifiedProdPooling(torch.nn.Module): class StratifiedProdPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_prod_pooling` function.""" """Thin wrapper over the `stratified_prod_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_prod_pooling(values, labels) return stratified_prod_pooling(values, labels)
class StratifiedMinPooling(torch.nn.Module): class StratifiedMinPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_min_pooling` function.""" """Thin wrapper over the `stratified_min_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_min_pooling(values, labels) return stratified_min_pooling(values, labels)
class StratifiedMaxPooling(torch.nn.Module): class StratifiedMaxPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_max_pooling` function.""" """Thin wrapper over the `stratified_max_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_max_pooling(values, labels) return stratified_max_pooling(values, labels)

View File

@ -10,6 +10,7 @@ from .initializers import (
class LinearTransform(torch.nn.Module): class LinearTransform(torch.nn.Module):
def __init__( def __init__(
self, self,
in_dim: int, in_dim: int,

View File

@ -93,6 +93,7 @@ class ProtoDataset(Dataset):
class NumpyDataset(torch.utils.data.TensorDataset): class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays.""" """Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, data, targets): def __init__(self, data, targets):
self.data = torch.Tensor(data) self.data = torch.Tensor(data)
self.targets = torch.LongTensor(targets) self.targets = torch.LongTensor(targets)
@ -102,6 +103,7 @@ class NumpyDataset(torch.utils.data.TensorDataset):
class CSVDataset(NumpyDataset): class CSVDataset(NumpyDataset):
"""Create a Dataset from a CSV file.""" """Create a Dataset from a CSV file."""
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0): def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
raw = np.genfromtxt( raw = np.genfromtxt(
filepath, filepath,

View File

@ -8,8 +8,13 @@ URL:
import warnings import warnings
from typing import Sequence, Union from typing import Sequence, Union
from sklearn.datasets import (load_iris, make_blobs, make_circles, from sklearn.datasets import (
make_classification, make_moons) load_iris,
make_blobs,
make_circles,
make_classification,
make_moons,
)
from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.abstract import NumpyDataset
@ -35,6 +40,7 @@ class Iris(NumpyDataset):
:param dims: select a subset of dimensions :param dims: select a subset of dimensions
""" """
def __init__(self, dims: Sequence[int] = None): def __init__(self, dims: Sequence[int] = None):
x, y = load_iris(return_X_y=True) x, y = load_iris(return_X_y=True)
if dims: if dims:
@ -49,6 +55,7 @@ class Blobs(NumpyDataset):
https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators. https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.
""" """
def __init__(self, def __init__(self,
num_samples: int = 300, num_samples: int = 300,
num_features: int = 2, num_features: int = 2,
@ -69,6 +76,7 @@ class Random(NumpyDataset):
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy. Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
""" """
def __init__(self, def __init__(self,
num_samples: int = 300, num_samples: int = 300,
num_features: int = 2, num_features: int = 2,
@ -104,6 +112,7 @@ class Circles(NumpyDataset):
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html
""" """
def __init__(self, def __init__(self,
num_samples: int = 300, num_samples: int = 300,
noise: float = 0.3, noise: float = 0.3,
@ -126,6 +135,7 @@ class Moons(NumpyDataset):
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html
""" """
def __init__(self, def __init__(self,
num_samples: int = 300, num_samples: int = 300,
noise: float = 0.3, noise: float = 0.3,

View File

@ -9,6 +9,7 @@ def make_spiral(num_samples=500, noise=0.3):
For use in Prototorch use `prototorch.datasets.Spiral` instead. For use in Prototorch use `prototorch.datasets.Spiral` instead.
""" """
def get_samples(n, delta_t): def get_samples(n, delta_t):
points = [] points = []
for i in range(n): for i in range(n):
@ -52,6 +53,7 @@ class Spiral(torch.utils.data.TensorDataset):
:param num_samples: number of random samples :param num_samples: number of random samples
:param noise: noise added to the spirals :param noise: noise added to the spirals
""" """
def __init__(self, num_samples: int = 500, noise: float = 0.3): def __init__(self, num_samples: int = 500, noise: float = 0.3):
x, y = make_spiral(num_samples, noise) x, y = make_spiral(num_samples, noise)
super().__init__(torch.Tensor(x), torch.LongTensor(y)) super().__init__(torch.Tensor(x), torch.LongTensor(y))

View File

@ -13,6 +13,7 @@ def make_xor(num_samples=500):
class XOR(torch.utils.data.TensorDataset): class XOR(torch.utils.data.TensorDataset):
"""Exclusive-or (XOR) dataset for binary classification.""" """Exclusive-or (XOR) dataset for binary classification."""
def __init__(self, num_samples: int = 500): def __init__(self, num_samples: int = 500):
x, y = make_xor(num_samples) x, y = make_xor(num_samples)
super().__init__(x, y) super().__init__(x, y)

View File

@ -4,6 +4,7 @@ import torch
class LambdaLayer(torch.nn.Module): class LambdaLayer(torch.nn.Module):
def __init__(self, fn, name=None): def __init__(self, fn, name=None):
super().__init__() super().__init__()
self.fn = fn self.fn = fn
@ -17,6 +18,7 @@ class LambdaLayer(torch.nn.Module):
class LossLayer(torch.nn.modules.loss._Loss): class LossLayer(torch.nn.modules.loss._Loss):
def __init__(self, def __init__(self,
fn, fn,
name=None, name=None,

View File

@ -13,6 +13,32 @@ import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
def generate_mesh(
minima: torch.TensorType,
maxima: torch.TensorType,
border: float = 1.0,
resolution: int = 100,
device: torch.device = None,
):
# Apply Border
ptp = maxima - minima
shift = border * ptp
minima -= shift
maxima += shift
# Generate Mesh
minima = minima.to(device).unsqueeze(1)
maxima = maxima.to(device).unsqueeze(1)
factors = torch.linspace(0, 1, resolution, device=device)
marginals = factors * maxima + ((1 - factors) * minima)
single_dimensions = torch.meshgrid(*marginals)
mesh_input = torch.stack([dim.ravel() for dim in single_dimensions], dim=1)
return mesh_input, single_dimensions
def mesh2d(x=None, border: float = 1.0, resolution: int = 100): def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
if x is not None: if x is not None:
x_shift = border * np.ptp(x[:, 0]) x_shift = border * np.ptp(x[:, 0])

View File

@ -12,4 +12,18 @@ multi_line_output = 3
include_trailing_comma = True include_trailing_comma = True
force_grid_wrap = 3 force_grid_wrap = 3
use_parentheses = True use_parentheses = True
line_length = 79 line_length = 79
[bumpversion]
current_version = 0.7.1
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize = {major}.{minor}.{patch}
message = build: bump version {current_version} → {new_version}
[bumpversion:file:setup.py]
[bumpversion:file:./prototorch/__init__.py]
[bumpversion:file:./docs/source/conf.py]

View File

@ -29,7 +29,7 @@ DATASETS = [
"tqdm", "tqdm",
] ]
DEV = [ DEV = [
"bumpversion", "bump2version",
"pre-commit", "pre-commit",
] ]
DOCS = [ DOCS = [
@ -43,7 +43,10 @@ EXAMPLES = [
"matplotlib", "matplotlib",
"torchinfo", "torchinfo",
] ]
TESTS = ["codecov", "pytest"] TESTS = [
"flake8",
"pytest",
]
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
@ -59,7 +62,7 @@ setup(
url=PROJECT_URL, url=PROJECT_URL,
download_url=DOWNLOAD_URL, download_url=DOWNLOAD_URL,
license="MIT", license="MIT",
python_requires=">=3.6", python_requires=">=3.7,<3.10",
install_requires=INSTALL_REQUIRES, install_requires=INSTALL_REQUIRES,
extras_require={ extras_require={
"datasets": DATASETS, "datasets": DATASETS,
@ -82,7 +85,6 @@ setup(
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",

View File

@ -404,6 +404,7 @@ def test_glvq_loss_one_hot_unequal():
# Activations # Activations
class TestActivations(unittest.TestCase): class TestActivations(unittest.TestCase):
def setUp(self): def setUp(self):
self.flist = ["identity", "sigmoid_beta", "swish_beta"] self.flist = ["identity", "sigmoid_beta", "swish_beta"]
self.x = torch.randn(1024, 1) self.x = torch.randn(1024, 1)
@ -418,6 +419,7 @@ class TestActivations(unittest.TestCase):
self.assertTrue(iscallable) self.assertTrue(iscallable)
def test_callable_deserialization(self): def test_callable_deserialization(self):
def dummy(x, **kwargs): def dummy(x, **kwargs):
return x return x
@ -462,6 +464,7 @@ class TestActivations(unittest.TestCase):
# Competitions # Competitions
class TestCompetitions(unittest.TestCase): class TestCompetitions(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
@ -515,6 +518,7 @@ class TestCompetitions(unittest.TestCase):
# Pooling # Pooling
class TestPooling(unittest.TestCase): class TestPooling(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
@ -615,6 +619,7 @@ class TestPooling(unittest.TestCase):
# Distances # Distances
class TestDistances(unittest.TestCase): class TestDistances(unittest.TestCase):
def setUp(self): def setUp(self):
self.nx, self.mx = 32, 2048 self.nx, self.mx = 32, 2048
self.ny, self.my = 8, 2048 self.ny, self.my = 8, 2048

View File

@ -12,6 +12,7 @@ from prototorch.datasets.abstract import Dataset, ProtoDataset
class TestAbstract(unittest.TestCase): class TestAbstract(unittest.TestCase):
def setUp(self): def setUp(self):
self.ds = Dataset("./artifacts") self.ds = Dataset("./artifacts")
@ -28,6 +29,7 @@ class TestAbstract(unittest.TestCase):
class TestProtoDataset(unittest.TestCase): class TestProtoDataset(unittest.TestCase):
def test_download(self): def test_download(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
_ = ProtoDataset("./artifacts", download=True) _ = ProtoDataset("./artifacts", download=True)
@ -38,6 +40,7 @@ class TestProtoDataset(unittest.TestCase):
class TestNumpyDataset(unittest.TestCase): class TestNumpyDataset(unittest.TestCase):
def test_list_init(self): def test_list_init(self):
ds = pt.datasets.NumpyDataset([1], [1]) ds = pt.datasets.NumpyDataset([1], [1])
self.assertEqual(len(ds), 1) self.assertEqual(len(ds), 1)
@ -50,6 +53,7 @@ class TestNumpyDataset(unittest.TestCase):
class TestCSVDataset(unittest.TestCase): class TestCSVDataset(unittest.TestCase):
def setUp(self): def setUp(self):
data = np.random.rand(100, 4) data = np.random.rand(100, 4)
targets = np.random.randint(2, size=(100, 1)) targets = np.random.randint(2, size=(100, 1))
@ -67,12 +71,14 @@ class TestCSVDataset(unittest.TestCase):
class TestSpiral(unittest.TestCase): class TestSpiral(unittest.TestCase):
def test_init(self): def test_init(self):
ds = pt.datasets.Spiral(num_samples=10) ds = pt.datasets.Spiral(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestIris(unittest.TestCase): class TestIris(unittest.TestCase):
def setUp(self): def setUp(self):
self.ds = pt.datasets.Iris() self.ds = pt.datasets.Iris()
@ -88,24 +94,28 @@ class TestIris(unittest.TestCase):
class TestBlobs(unittest.TestCase): class TestBlobs(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Blobs(num_samples=10) ds = pt.datasets.Blobs(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestRandom(unittest.TestCase): class TestRandom(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Random(num_samples=10) ds = pt.datasets.Random(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestCircles(unittest.TestCase): class TestCircles(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Circles(num_samples=10) ds = pt.datasets.Circles(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestMoons(unittest.TestCase): class TestMoons(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Moons(num_samples=10) ds = pt.datasets.Moons(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)