Compare commits

..

12 Commits

Author SHA1 Message Date
Alexander Engelsberger
0788718c31
ci: cache test IV 2021-11-05 15:14:59 +01:00
Alexander Engelsberger
4f5c4ebe8f
ci: cache test III 2021-11-05 15:09:21 +01:00
Alexander Engelsberger
ae2a8e54ef
ci: cache test II 2021-11-05 14:59:01 +01:00
Alexander Engelsberger
d9be100c1f
ci: use pip cache in jenkins 2021-11-05 14:55:12 +01:00
Alexander Engelsberger
9d1dc7320f
ci: fix jenkinsfile 2021-11-05 14:32:37 +01:00
Alexander Engelsberger
d11ab71b7e
ci: unit tests in jenkins 2021-11-05 14:30:08 +01:00
Alexander Engelsberger
59037e1a50
ci: upgrade pip before install 2021-11-04 10:53:51 +01:00
Alexander Engelsberger
a19b99be82
ci: container debugging III 2021-11-04 10:50:44 +01:00
Alexander Engelsberger
f7e7558338
ci: container debugging II 2021-11-04 10:42:53 +01:00
Alexander Engelsberger
d57648f9d6
ci: container debugging 2021-11-04 10:41:28 +01:00
Alexander Engelsberger
d24f580bf0
ci: install dependencies with user flag 2021-11-04 09:55:58 +01:00
Jensun Ravichandran
916973c3e8 ci: migrate to jenkins 2021-11-03 16:26:32 +01:00
36 changed files with 281 additions and 491 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.7.6 current_version = 0.7.1
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

7
.ci/python310.Dockerfile Normal file
View File

@ -0,0 +1,7 @@
# Build image for the Jenkins "3.10" unit-test stage (see Jenkinsfile).
# BUG FIX: this file is named python310.Dockerfile and is used by the
# '3.10' pipeline stage, but it was based on python:3.9.
FROM python:3.10
# Unprivileged user with UID 1000 to match the Jenkins agent's UID, so
# bind-mounted workspace files keep correct ownership.
RUN adduser --uid 1000 jenkins
USER jenkins
# Pre-create pip's cache directory so the mounted pip-cache volume
# (see Jenkinsfile: -v pip-cache:/home/jenkins/.cache/pip) is writable.
RUN mkdir -p /home/jenkins/.cache/pip

7
.ci/python36.Dockerfile Normal file
View File

@ -0,0 +1,7 @@
# Build image for the Jenkins "3.6" unit-test stage (see Jenkinsfile).
FROM python:3.6
# Unprivileged user with UID 1000 to match the Jenkins agent's UID, so
# bind-mounted workspace files keep correct ownership.
RUN adduser --uid 1000 jenkins
USER jenkins
# Pre-create pip's cache directory so the mounted pip-cache volume
# (see Jenkinsfile: -v pip-cache:/home/jenkins/.cache/pip) is writable.
RUN mkdir -p /home/jenkins/.cache/pip

15
.codacy.yml Normal file
View File

@ -0,0 +1,15 @@
# To validate the contents of your configuration file
# run the following command in the folder where the configuration file is located:
# codacy-analysis-cli validate-configuration --directory `pwd`
# To analyse, run:
# codacy-analysis-cli analyse --tool remark-lint --directory `pwd`
---
engines:
  pylintpython3:
    exclude_paths:
      # Do not lint the engine configuration itself.
      - config/engines.yml
  remark-lint:
    exclude_paths:
      - config/engines.yml
# Global exclusion: skip static analysis of the test suite.
exclude_paths:
  - 'tests/**'

2
.codecov.yml Normal file
View File

@ -0,0 +1,2 @@
# Codecov: only post a pull-request comment when the coverage
# numbers actually change, to cut down on comment noise.
comment:
  require_changes: yes

View File

@ -5,71 +5,33 @@ name: tests
on: on:
push: push:
branches: [ master, dev ]
pull_request: pull_request:
branches: [master] branches: [ master ]
jobs: jobs:
style: build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v3.0.0
compatibility:
needs: style
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest]
exclude:
- os: windows-latest
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
- os: windows-latest
python-version: "3.10"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[all]
- name: Test with pytest
run: |
pytest
publish_pypi:
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
needs: compatibility
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v2
- name: Set up Python 3.10 - name: Set up Python 3.9
uses: actions/setup-python@v4 uses: actions/setup-python@v1
with: with:
python-version: "3.11" python-version: 3.9
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install .[all] pip install .[all]
pip install wheel - name: Lint with flake8
- name: Build package run: |
run: python setup.py sdist bdist_wheel pip install flake8
- name: Publish a Python distribution to PyPI # stop the build if there are Python syntax errors or undefined names
uses: pypa/gh-action-pypi-publish@release/v1 flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
with: # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
user: __token__ flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
password: ${{ secrets.PYPI_API_TOKEN }} - name: Test with pytest
run: |
pip install pytest
pytest

View File

@ -3,7 +3,7 @@
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0 rev: v4.0.1
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
- id: end-of-file-fixer - id: end-of-file-fixer
@ -13,36 +13,36 @@ repos:
- id: check-case-conflict - id: check-case-conflict
- repo: https://github.com/myint/autoflake - repo: https://github.com/myint/autoflake
rev: v2.1.1 rev: v1.4
hooks: hooks:
- id: autoflake - id: autoflake
- repo: http://github.com/PyCQA/isort - repo: http://github.com/PyCQA/isort
rev: 5.12.0 rev: 5.8.0
hooks: hooks:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0 rev: v0.902
hooks: hooks:
- id: mypy - id: mypy
files: prototorch files: prototorch
additional_dependencies: [types-pkg_resources] additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.32.0 rev: v0.31.0
hooks: hooks:
- id: yapf - id: yapf
- repo: https://github.com/pre-commit/pygrep-hooks - repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0 rev: v1.9.0
hooks: hooks:
- id: python-use-type-annotations - id: python-use-type-annotations
- id: python-no-log-warn - id: python-no-log-warn
- id: python-check-blanket-noqa - id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v3.7.0 rev: v2.19.4
hooks: hooks:
- id: pyupgrade - id: pyupgrade

41
Jenkinsfile vendored Normal file
View File

@ -0,0 +1,41 @@
// Declarative Jenkins pipeline: run the unit-test suite in parallel against
// the oldest (3.6) and newest (3.10) Python versions covered by CI.
pipeline {
    // No global agent; each stage provides its own Docker-based agent.
    agent none
    stages {
        stage('Unit Tests') {
            parallel {
                stage('3.6'){
                    agent{
                        dockerfile {
                            filename 'python36.Dockerfile'
                            dir '.ci'
                            // Named volume shares pip's download cache across builds.
                            args '-v pip-cache:/home/jenkins/.cache/pip'
                        }
                    }
                    steps {
                        sh 'pip install pip --upgrade --progress-bar off'
                        sh 'pip install .[all] --progress-bar off'
                        // pytest lands in ~/.local/bin because the container runs as
                        // the unprivileged 'jenkins' user (pip user-install).
                        sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
                        // Publish coverage (Cobertura XML) and JUnit test results.
                        cobertura coberturaReportFile: 'reports/coverage.xml'
                        junit 'reports/**/*.xml'
                    }
                }
                stage('3.10'){
                    agent{
                        dockerfile {
                            filename 'python310.Dockerfile'
                            dir '.ci'
                            // Same shared pip cache volume as the 3.6 stage.
                            args '-v pip-cache:/home/jenkins/.cache/pip'
                        }
                    }
                    steps {
                        sh 'pip install pip --upgrade --progress-bar off'
                        sh 'pip install .[all] --progress-bar off'
                        sh '~/.local/bin/pytest -v --junitxml=reports/result.xml --cov=prototorch/ --cov-report=xml:reports/coverage.xml'
                        cobertura coberturaReportFile: 'reports/coverage.xml'
                        junit 'reports/**/*.xml'
                    }
                }
            }
        }
    }
}

View File

@ -1,7 +1,6 @@
MIT License MIT License
Copyright (c) 2020 Saxon Institute for Computational Intelligence and Machine Copyright (c) 2020 si-cim
Learning (SICIM)
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

View File

@ -2,9 +2,12 @@
![ProtoTorch Logo](https://prototorch.readthedocs.io/en/latest/_static/horizontal-lockup.png) ![ProtoTorch Logo](https://prototorch.readthedocs.io/en/latest/_static/horizontal-lockup.png)
[![Build Status](https://api.travis-ci.com/si-cim/prototorch.svg?branch=master)](https://travis-ci.com/github/si-cim/prototorch)
![tests](https://github.com/si-cim/prototorch/workflows/tests/badge.svg) ![tests](https://github.com/si-cim/prototorch/workflows/tests/badge.svg)
[![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch?color=yellow&label=version)](https://github.com/si-cim/prototorch/releases) [![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch?color=yellow&label=version)](https://github.com/si-cim/prototorch/releases)
[![PyPI](https://img.shields.io/pypi/v/prototorch)](https://pypi.org/project/prototorch/) [![PyPI](https://img.shields.io/pypi/v/prototorch)](https://pypi.org/project/prototorch/)
[![codecov](https://codecov.io/gh/si-cim/prototorch/branch/master/graph/badge.svg)](https://codecov.io/gh/si-cim/prototorch)
[![Codacy Badge](https://api.codacy.com/project/badge/Grade/76273904bf9343f0a8b29cd8aca242e7)](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=si-cim/prototorch&amp;utm_campaign=Badge_Grade)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE) [![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE)
*Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow) *Tensorflow users, see:* [ProtoFlow](https://github.com/si-cim/protoflow)

46
deprecated.travis.yml Normal file
View File

@ -0,0 +1,46 @@
# Legacy Travis CI configuration, kept for reference after the migration
# to Jenkins (see Jenkinsfile). Not consumed by any active CI system.
dist: bionic
sudo: false
language: python
python:
  - 3.9
  - 3.8
  - 3.7
  - 3.6
cache:
  directories:
    - "$HOME/.cache/pip"
    - "./tests/artifacts"
    - "$HOME/datasets"
install:
  - pip install .[all] --progress-bar off
# Generate code coverage report
script:
  - coverage run -m pytest
# Push the results to codecov
after_success:
  - bash <(curl -s https://codecov.io/bash)
# Publish on PyPI
jobs:
  include:
    - stage: build
      python: 3.9
      script: echo "Starting Pypi build"
      deploy:
        provider: pypi
        username: __token__
        distributions: "sdist bdist_wheel"
        password:
          secure: rVQNCxKIuiEtMz4zLSsjdt6spG7cf3miKN5eqjxZfcELALHxAV4w/+CideQObOn3u9emmxb87R9XWKcogqK2MXqnuIcY4mWg7HUqaip1bhz/4YiVXjFILcG6itjX9IUF1DrtjKKRk6xryucSZcEB7yTcXz1hQTb768KWlLlKOVTRNwr7j07eyeafexz/L2ANQCqfOZgS4b0k2AMeDBRPykPULtyeneEFlb6MJZ2MxeqtTNVK4b/6VsQSZwQ9jGJNGWonn5Y287gHmzvEcymSJogTe2taxGBWawPnOsibws9v88DEAHdsEvYdnqEE3hFl0R5La2Lkjd8CjNUYegxioQ57i3WNS3iksq10ZLMCbH29lb9YPG7r6Y8z9H85735kV2gKLdf+o7SPS03TRgjSZKN6pn4pLG0VWkxC6l8VfLuJnRNTHX4g6oLQwOWIBbxybn9Zw/yLjAXAJNgBHt5v86H6Jfi1Va4AhEV6itkoH9IM3/uDhrE/mmorqyVled/CPNtBWNTyoDevLNxMUDnbuhH0JzLki+VOjKnTxEfq12JB8X9faFG5BjvU9oGjPPewrp5DGGzg6KDra7dikciWUxE1eTFFDhMyG1CFGcjKlDvlAGHyI6Kih35egGUeq+N/pitr2330ftM9Dm4rWpOTxPyCI89bXKssx/MgmLG7kSM=
      on:
        tags: true
        skip_existing: true
# The password is encrypted with:
# `cd prototorch && travis encrypt your-pypi-api-token --add deploy.password`
# See https://docs.travis-ci.com/user/deployment/pypi and
# https://github.com/travis-ci/travis.rb#installation
# for more details
# Note: The encrypt command does not work well in ZSH.

View File

@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.7.6" release = "0.7.1"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
@ -120,7 +120,7 @@ html_css_files = [
# -- Options for HTMLHelp output ------------------------------------------ # -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = "prototorchdoc" htmlhelp_basename = "protoflowdoc"
# -- Options for LaTeX output --------------------------------------------- # -- Options for LaTeX output ---------------------------------------------

View File

@ -1,7 +1,5 @@
"""ProtoTorch CBC example using 2D Iris data.""" """ProtoTorch CBC example using 2D Iris data."""
import logging
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
@ -9,7 +7,6 @@ import prototorch as pt
class CBC(torch.nn.Module): class CBC(torch.nn.Module):
def __init__(self, data, **kwargs): def __init__(self, data, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.components_layer = pt.components.ReasoningComponents( self.components_layer = pt.components.ReasoningComponents(
@ -26,7 +23,6 @@ class CBC(torch.nn.Module):
class VisCBC2D(): class VisCBC2D():
def __init__(self, model, data): def __init__(self, model, data):
self.model = model self.model = model
self.x_train, self.y_train = pt.utils.parse_data_arg(data) self.x_train, self.y_train = pt.utils.parse_data_arg(data)
@ -36,7 +32,7 @@ class VisCBC2D():
self.resolution = 100 self.resolution = 100
self.cmap = "viridis" self.cmap = "viridis"
def on_train_epoch_end(self): def on_epoch_end(self):
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
_components = self.model.components_layer._components.detach() _components = self.model.components_layer._components.detach()
ax = self.fig.gca() ax = self.fig.gca()
@ -96,5 +92,5 @@ if __name__ == "__main__":
correct += (y_pred.argmax(1) == y).float().sum(0) correct += (y_pred.argmax(1) == y).float().sum(0)
acc = 100 * correct / len(train_ds) acc = 100 * correct / len(train_ds)
logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%") print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
vis.on_train_epoch_end() vis.on_epoch_end()

View File

@ -1,76 +0,0 @@
"""ProtoTorch GMLVQ example using Iris data."""
import torch
import prototorch as pt
class GMLVQ(torch.nn.Module):
    """
    Implementation of Generalized Matrix Learning Vector Quantization.
    """

    def __init__(self, data, **kwargs):
        super().__init__(**kwargs)

        # One labeled prototype per class (distribution=[1, 1, 1] — three
        # classes, one component each), initialized from the data with noise.
        # NOTE(review): SMCI presumably is a stratified/class-aware mean
        # component initializer — confirm against prototorch.initializers.
        self.components_layer = pt.components.LabeledComponents(
            distribution=[1, 1, 1],
            components_initializer=pt.initializers.SMCI(data, noise=0.1),
        )

        # Learnable square linear transform (omega matrix) of size
        # n_features x n_features; len(data[0][0]) is the feature count of
        # the first sample.
        self.backbone = pt.transforms.Omega(
            len(data[0][0]),
            len(data[0][0]),
            pt.initializers.RandomLinearTransformInitializer(),
        )

    def forward(self, data):
        """
        Forward function that returns a tuple of dissimilarities and label information.
        Feed into GLVQLoss to get a complete GMLVQ model.
        """
        components, label = self.components_layer()

        # Project both inputs and prototypes through the shared omega
        # transform, then measure distance in the latent space.
        latent_x = self.backbone(data)
        latent_components = self.backbone(components)

        distance = pt.distances.squared_euclidean_distance(
            latent_x, latent_components)

        return distance, label

    def predict(self, data):
        """
        The GMLVQ has a modified prediction step, where a competition layer is applied.
        """
        components, label = self.components_layer()
        distance = pt.distances.squared_euclidean_distance(data, components)
        # Winner-takes-all competition: label of the closest prototype.
        # NOTE(review): unlike forward(), prediction measures distance in the
        # *input* space (no backbone projection) — confirm this is intended.
        winning_label = pt.competitions.wtac(distance, label)
        return winning_label
if __name__ == "__main__":
train_ds = pt.datasets.Iris()
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
model = GMLVQ(train_ds)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
criterion = pt.losses.GLVQLoss()
for epoch in range(200):
correct = 0.0
for x, y in train_loader:
d, labels = model(x)
loss = criterion(d, y, labels).mean(0)
optimizer.zero_grad()
loss.backward()
optimizer.step()
with torch.no_grad():
y_pred = model.predict(x)
correct += (y_pred == y).float().sum(0)
acc = 100 * correct / len(train_ds)
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")

View File

@ -1,23 +1,28 @@
"""ProtoTorch package""" """ProtoTorch package"""
import pkgutil import pkgutil
from typing import List
import pkg_resources import pkg_resources
from . import datasets # noqa: F401 from . import (
from . import nn # noqa: F401 datasets,
from . import utils # noqa: F401 nn,
from .core import competitions # noqa: F401 utils,
from .core import components # noqa: F401 )
from .core import distances # noqa: F401 from .core import (
from .core import initializers # noqa: F401 competitions,
from .core import losses # noqa: F401 components,
from .core import pooling # noqa: F401 distances,
from .core import similarities # noqa: F401 initializers,
from .core import transforms # noqa: F401 losses,
pooling,
similarities,
transforms,
)
# Core Setup # Core Setup
__version__ = "0.7.6" __version__ = "0.7.1"
__all_core__ = [ __all_core__ = [
"competitions", "competitions",
@ -35,7 +40,7 @@ __all_core__ = [
] ]
# Plugin Loader # Plugin Loader
__path__ = pkgutil.extend_path(__path__, __name__) __path__: List[str] = pkgutil.extend_path(__path__, __name__)
def discover_plugins(): def discover_plugins():

View File

@ -38,7 +38,7 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
pk = A pk = A
nk = (1 - A) * B nk = (1 - A) * B
numerator = (detections @ (pk - nk).T) + nk.sum(1) numerator = (detections @ (pk - nk).T) + nk.sum(1)
probs = numerator / ((pk + nk).sum(1) + 1e-8) probs = numerator / (pk + nk).sum(1)
return probs return probs
@ -48,7 +48,6 @@ class WTAC(torch.nn.Module):
Thin wrapper over the `wtac` function. Thin wrapper over the `wtac` function.
""" """
def forward(self, distances, labels): # pylint: disable=no-self-use def forward(self, distances, labels): # pylint: disable=no-self-use
return wtac(distances, labels) return wtac(distances, labels)
@ -59,7 +58,6 @@ class LTAC(torch.nn.Module):
Thin wrapper over the `wtac` function. Thin wrapper over the `wtac` function.
""" """
def forward(self, probs, labels): # pylint: disable=no-self-use def forward(self, probs, labels): # pylint: disable=no-self-use
return wtac(-1.0 * probs, labels) return wtac(-1.0 * probs, labels)
@ -70,7 +68,6 @@ class KNNC(torch.nn.Module):
Thin wrapper over the `knnc` function. Thin wrapper over the `knnc` function.
""" """
def __init__(self, k=1, **kwargs): def __init__(self, k=1, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.k = k self.k = k
@ -88,6 +85,5 @@ class CBCC(torch.nn.Module):
Thin wrapper over the `cbcc` function. Thin wrapper over the `cbcc` function.
""" """
def forward(self, detections, reasonings): # pylint: disable=no-self-use def forward(self, detections, reasonings): # pylint: disable=no-self-use
return cbcc(detections, reasonings) return cbcc(detections, reasonings)

View File

@ -6,8 +6,7 @@ from typing import Union
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from prototorch.utils import parse_distribution from ..utils import parse_distribution
from .initializers import ( from .initializers import (
AbstractClassAwareCompInitializer, AbstractClassAwareCompInitializer,
AbstractComponentsInitializer, AbstractComponentsInitializer,
@ -64,7 +63,6 @@ def get_cikwargs(init, distribution):
class AbstractComponents(torch.nn.Module): class AbstractComponents(torch.nn.Module):
"""Abstract class for all components modules.""" """Abstract class for all components modules."""
@property @property
def num_components(self): def num_components(self):
"""Current number of components.""" """Current number of components."""
@ -87,7 +85,6 @@ class AbstractComponents(torch.nn.Module):
class Components(AbstractComponents): class Components(AbstractComponents):
"""A set of adaptable Tensors.""" """A set of adaptable Tensors."""
def __init__(self, num_components: int, def __init__(self, num_components: int,
initializer: AbstractComponentsInitializer): initializer: AbstractComponentsInitializer):
super().__init__() super().__init__()
@ -115,7 +112,6 @@ class Components(AbstractComponents):
class AbstractLabels(torch.nn.Module): class AbstractLabels(torch.nn.Module):
"""Abstract class for all labels modules.""" """Abstract class for all labels modules."""
@property @property
def labels(self): def labels(self):
return self._labels.cpu() return self._labels.cpu()
@ -156,7 +152,6 @@ class AbstractLabels(torch.nn.Module):
class Labels(AbstractLabels): class Labels(AbstractLabels):
"""A set of standalone labels.""" """A set of standalone labels."""
def __init__(self, def __init__(self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],
initializer: AbstractLabelsInitializer = LabelsInitializer()): initializer: AbstractLabelsInitializer = LabelsInitializer()):
@ -187,7 +182,6 @@ class Labels(AbstractLabels):
class LabeledComponents(AbstractComponents): class LabeledComponents(AbstractComponents):
"""A set of adaptable components and corresponding unadaptable labels.""" """A set of adaptable components and corresponding unadaptable labels."""
def __init__( def __init__(
self, self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],
@ -255,7 +249,6 @@ class Reasonings(torch.nn.Module):
The `reasonings` tensor is of shape [num_components, num_classes, 2]. The `reasonings` tensor is of shape [num_components, num_classes, 2].
""" """
def __init__( def __init__(
self, self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],
@ -315,7 +308,6 @@ class ReasoningComponents(AbstractComponents):
three element probability distribution. three element probability distribution.
""" """
def __init__( def __init__(
self, self,
distribution: Union[dict, list, tuple], distribution: Union[dict, list, tuple],

View File

@ -11,7 +11,7 @@ def squared_euclidean_distance(x, y):
**Alias:** **Alias:**
``prototorch.functions.distances.sed`` ``prototorch.functions.distances.sed``
""" """
x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
expanded_x = x.unsqueeze(dim=1) expanded_x = x.unsqueeze(dim=1)
batchwise_difference = y - expanded_x batchwise_difference = y - expanded_x
differences_raised = torch.pow(batchwise_difference, 2) differences_raised = torch.pow(batchwise_difference, 2)
@ -27,14 +27,14 @@ def euclidean_distance(x, y):
:returns: Distance Tensor of shape :math:`X \times Y` :returns: Distance Tensor of shape :math:`X \times Y`
:rtype: `torch.tensor` :rtype: `torch.tensor`
""" """
x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
distances_raised = squared_euclidean_distance(x, y) distances_raised = squared_euclidean_distance(x, y)
distances = torch.sqrt(distances_raised) distances = torch.sqrt(distances_raised)
return distances return distances
def euclidean_distance_v2(x, y): def euclidean_distance_v2(x, y):
x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
diff = y - x.unsqueeze(1) diff = y - x.unsqueeze(1)
pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt() pairwise_distances = (diff @ diff.permute((0, 2, 1))).sqrt()
# Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the # Passing `dim1=-2` and `dim2=-1` to `diagonal()` takes the
@ -54,7 +54,7 @@ def lpnorm_distance(x, y, p):
:param p: p parameter of the lp norm :param p: p parameter of the lp norm
""" """
x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
distances = torch.cdist(x, y, p=p) distances = torch.cdist(x, y, p=p)
return distances return distances
@ -66,7 +66,7 @@ def omega_distance(x, y, omega):
:param `torch.tensor` omega: Two dimensional matrix :param `torch.tensor` omega: Two dimensional matrix
""" """
x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
projected_x = x @ omega projected_x = x @ omega
projected_y = y @ omega projected_y = y @ omega
distances = squared_euclidean_distance(projected_x, projected_y) distances = squared_euclidean_distance(projected_x, projected_y)
@ -80,7 +80,7 @@ def lomega_distance(x, y, omegas):
:param `torch.tensor` omegas: Three dimensional matrix :param `torch.tensor` omegas: Three dimensional matrix
""" """
x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
projected_x = x @ omegas projected_x = x @ omegas
projected_y = torch.diagonal(y @ omegas).T projected_y = torch.diagonal(y @ omegas).T
expanded_y = torch.unsqueeze(projected_y, dim=1) expanded_y = torch.unsqueeze(projected_y, dim=1)

View File

@ -11,7 +11,7 @@ from typing import (
import torch import torch
from prototorch.utils import parse_data_arg, parse_distribution from ..utils import parse_data_arg, parse_distribution
# Components # Components
@ -26,18 +26,11 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
Use this to 'generate' pre-initialized components elsewhere. Use this to 'generate' pre-initialized components elsewhere.
""" """
def __init__(self, components): def __init__(self, components):
self.components = components self.components = components
def generate(self, num_components: int = 0): def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return `self.components`.""" """Ignore `num_components` and simply return `self.components`."""
provided_num_components = len(self.components)
if provided_num_components != num_components:
wmsg = f"The number of components ({provided_num_components}) " \
f"provided to {self.__class__.__name__} " \
f"does not match the expected number ({num_components})."
warnings.warn(wmsg)
if not isinstance(self.components, torch.Tensor): if not isinstance(self.components, torch.Tensor):
wmsg = f"Converting components to {torch.Tensor}..." wmsg = f"Converting components to {torch.Tensor}..."
warnings.warn(wmsg) warnings.warn(wmsg)
@ -47,7 +40,6 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
class ShapeAwareCompInitializer(AbstractComponentsInitializer): class ShapeAwareCompInitializer(AbstractComponentsInitializer):
"""Abstract class for all dimension-aware components initializers.""" """Abstract class for all dimension-aware components initializers."""
def __init__(self, shape: Union[Iterable, int]): def __init__(self, shape: Union[Iterable, int]):
if isinstance(shape, Iterable): if isinstance(shape, Iterable):
self.component_shape = tuple(shape) self.component_shape = tuple(shape)
@ -61,7 +53,6 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
class ZerosCompInitializer(ShapeAwareCompInitializer): class ZerosCompInitializer(ShapeAwareCompInitializer):
"""Generate zeros corresponding to the components shape.""" """Generate zeros corresponding to the components shape."""
def generate(self, num_components: int): def generate(self, num_components: int):
components = torch.zeros((num_components, ) + self.component_shape) components = torch.zeros((num_components, ) + self.component_shape)
return components return components
@ -69,7 +60,6 @@ class ZerosCompInitializer(ShapeAwareCompInitializer):
class OnesCompInitializer(ShapeAwareCompInitializer): class OnesCompInitializer(ShapeAwareCompInitializer):
"""Generate ones corresponding to the components shape.""" """Generate ones corresponding to the components shape."""
def generate(self, num_components: int): def generate(self, num_components: int):
components = torch.ones((num_components, ) + self.component_shape) components = torch.ones((num_components, ) + self.component_shape)
return components return components
@ -77,7 +67,6 @@ class OnesCompInitializer(ShapeAwareCompInitializer):
class FillValueCompInitializer(OnesCompInitializer): class FillValueCompInitializer(OnesCompInitializer):
"""Generate components with the provided `fill_value`.""" """Generate components with the provided `fill_value`."""
def __init__(self, shape, fill_value: float = 1.0): def __init__(self, shape, fill_value: float = 1.0):
super().__init__(shape) super().__init__(shape)
self.fill_value = fill_value self.fill_value = fill_value
@ -90,7 +79,6 @@ class FillValueCompInitializer(OnesCompInitializer):
class UniformCompInitializer(OnesCompInitializer): class UniformCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a continuous uniform distribution.""" """Generate components by sampling from a continuous uniform distribution."""
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0): def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
super().__init__(shape) super().__init__(shape)
self.minimum = minimum self.minimum = minimum
@ -105,7 +93,6 @@ class UniformCompInitializer(OnesCompInitializer):
class RandomNormalCompInitializer(OnesCompInitializer): class RandomNormalCompInitializer(OnesCompInitializer):
"""Generate components by sampling from a standard normal distribution.""" """Generate components by sampling from a standard normal distribution."""
def __init__(self, shape, shift=0.0, scale=1.0): def __init__(self, shape, shift=0.0, scale=1.0):
super().__init__(shape) super().__init__(shape)
self.shift = shift self.shift = shift
@ -126,7 +113,6 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
`data` has to be a torch tensor. `data` has to be a torch tensor.
""" """
def __init__(self, def __init__(self,
data: torch.Tensor, data: torch.Tensor,
noise: float = 0.0, noise: float = 0.0,
@ -151,7 +137,6 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
class DataAwareCompInitializer(AbstractDataAwareCompInitializer): class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
"""'Generate' the components from the provided data.""" """'Generate' the components from the provided data."""
def generate(self, num_components: int = 0): def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return transformed `self.data`.""" """Ignore `num_components` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data) components = self.generate_end_hook(self.data)
@ -160,7 +145,6 @@ class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
class SelectionCompInitializer(AbstractDataAwareCompInitializer): class SelectionCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by uniformly sampling from the provided data.""" """Generate components by uniformly sampling from the provided data."""
def generate(self, num_components: int): def generate(self, num_components: int):
indices = torch.LongTensor(num_components).random_(0, len(self.data)) indices = torch.LongTensor(num_components).random_(0, len(self.data))
samples = self.data[indices] samples = self.data[indices]
@ -170,7 +154,6 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
class MeanCompInitializer(AbstractDataAwareCompInitializer): class MeanCompInitializer(AbstractDataAwareCompInitializer):
"""Generate components by computing the mean of the provided data.""" """Generate components by computing the mean of the provided data."""
def generate(self, num_components: int): def generate(self, num_components: int):
mean = self.data.mean(dim=0) mean = self.data.mean(dim=0)
repeat_dim = [num_components] + [1] * len(mean.shape) repeat_dim = [num_components] + [1] * len(mean.shape)
@ -189,7 +172,6 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
target tensors. target tensors.
""" """
def __init__(self, def __init__(self,
data, data,
noise: float = 0.0, noise: float = 0.0,
@ -217,7 +199,6 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
"""'Generate' components from provided data and requested distribution.""" """'Generate' components from provided data and requested distribution."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
"""Ignore `distribution` and simply return transformed `self.data`.""" """Ignore `distribution` and simply return transformed `self.data`."""
components = self.generate_end_hook(self.data) components = self.generate_end_hook(self.data)
@ -226,7 +207,6 @@ class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
"""Abstract class for all stratified components initializers.""" """Abstract class for all stratified components initializers."""
@property @property
@abstractmethod @abstractmethod
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]: def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
@ -237,8 +217,6 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
components = torch.tensor([]) components = torch.tensor([])
for k, v in distribution.items(): for k, v in distribution.items():
stratified_data = self.data[self.targets == k] stratified_data = self.data[self.targets == k]
if len(stratified_data) == 0:
raise ValueError(f"No data available for class {k}.")
initializer = self.subinit_type( initializer = self.subinit_type(
stratified_data, stratified_data,
noise=self.noise, noise=self.noise,
@ -251,7 +229,6 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer): class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components using stratified sampling from the provided data.""" """Generate components using stratified sampling from the provided data."""
@property @property
def subinit_type(self): def subinit_type(self):
return SelectionCompInitializer return SelectionCompInitializer
@ -259,7 +236,6 @@ class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer): class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
"""Generate components at stratified means of the provided data.""" """Generate components at stratified means of the provided data."""
@property @property
def subinit_type(self): def subinit_type(self):
return MeanCompInitializer return MeanCompInitializer
@ -268,7 +244,6 @@ class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
# Labels # Labels
class AbstractLabelsInitializer(ABC): class AbstractLabelsInitializer(ABC):
"""Abstract class for all labels initializers.""" """Abstract class for all labels initializers."""
@abstractmethod @abstractmethod
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
... ...
@ -280,7 +255,6 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
Use this to 'generate' pre-initialized labels elsewhere. Use this to 'generate' pre-initialized labels elsewhere.
""" """
def __init__(self, labels): def __init__(self, labels):
self.labels = labels self.labels = labels
@ -299,7 +273,6 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
class DataAwareLabelsInitializer(AbstractLabelsInitializer): class DataAwareLabelsInitializer(AbstractLabelsInitializer):
"""'Generate' the labels from a torch Dataset.""" """'Generate' the labels from a torch Dataset."""
def __init__(self, data): def __init__(self, data):
self.data, self.targets = parse_data_arg(data) self.data, self.targets = parse_data_arg(data)
@ -310,7 +283,6 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
class LabelsInitializer(AbstractLabelsInitializer): class LabelsInitializer(AbstractLabelsInitializer):
"""Generate labels from `distribution`.""" """Generate labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
labels_list = [] labels_list = []
@ -322,7 +294,6 @@ class LabelsInitializer(AbstractLabelsInitializer):
class OneHotLabelsInitializer(LabelsInitializer): class OneHotLabelsInitializer(LabelsInitializer):
"""Generate one-hot-encoded labels from `distribution`.""" """Generate one-hot-encoded labels from `distribution`."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
distribution = parse_distribution(distribution) distribution = parse_distribution(distribution)
num_classes = len(distribution.keys()) num_classes = len(distribution.keys())
@ -341,7 +312,6 @@ def compute_distribution_shape(distribution):
class AbstractReasoningsInitializer(ABC): class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers.""" """Abstract class for all reasonings initializers."""
def __init__(self, components_first: bool = True): def __init__(self, components_first: bool = True):
self.components_first = components_first self.components_first = components_first
@ -362,7 +332,6 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
Use this to 'generate' pre-initialized reasonings elsewhere. Use this to 'generate' pre-initialized reasonings elsewhere.
""" """
def __init__(self, reasonings, **kwargs): def __init__(self, reasonings, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.reasonings = reasonings self.reasonings = reasonings
@ -380,7 +349,6 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
class ZerosReasoningsInitializer(AbstractReasoningsInitializer): class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with zeros.""" """Reasonings are all initialized with zeros."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
shape = compute_distribution_shape(distribution) shape = compute_distribution_shape(distribution)
reasonings = torch.zeros(*shape) reasonings = torch.zeros(*shape)
@ -390,7 +358,6 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
class OnesReasoningsInitializer(AbstractReasoningsInitializer): class OnesReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are all initialized with ones.""" """Reasonings are all initialized with ones."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
shape = compute_distribution_shape(distribution) shape = compute_distribution_shape(distribution)
reasonings = torch.ones(*shape) reasonings = torch.ones(*shape)
@ -400,7 +367,6 @@ class OnesReasoningsInitializer(AbstractReasoningsInitializer):
class RandomReasoningsInitializer(AbstractReasoningsInitializer): class RandomReasoningsInitializer(AbstractReasoningsInitializer):
"""Reasonings are randomly initialized.""" """Reasonings are randomly initialized."""
def __init__(self, minimum=0.4, maximum=0.6, **kwargs): def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.minimum = minimum self.minimum = minimum
@ -415,7 +381,6 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
"""Each component reasons positively for exactly one class.""" """Each component reasons positively for exactly one class."""
def generate(self, distribution: Union[dict, list, tuple]): def generate(self, distribution: Union[dict, list, tuple]):
num_components, num_classes, _ = compute_distribution_shape( num_components, num_classes, _ = compute_distribution_shape(
distribution) distribution)
@ -434,7 +399,6 @@ class AbstractTransformInitializer(ABC):
class AbstractLinearTransformInitializer(AbstractTransformInitializer): class AbstractLinearTransformInitializer(AbstractTransformInitializer):
"""Abstract class for all linear transform initializers.""" """Abstract class for all linear transform initializers."""
def __init__(self, out_dim_first: bool = False): def __init__(self, out_dim_first: bool = False):
self.out_dim_first = out_dim_first self.out_dim_first = out_dim_first
@ -451,7 +415,6 @@ class AbstractLinearTransformInitializer(AbstractTransformInitializer):
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer): class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with zeros.""" """Initialize a matrix with zeros."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim) weights = torch.zeros(in_dim, out_dim)
return self.generate_end_hook(weights) return self.generate_end_hook(weights)
@ -459,23 +422,13 @@ class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with ones.""" """Initialize a matrix with ones."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
weights = torch.ones(in_dim, out_dim) weights = torch.ones(in_dim, out_dim)
return self.generate_end_hook(weights) return self.generate_end_hook(weights)
class RandomLinearTransformInitializer(AbstractLinearTransformInitializer): class EyeTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with random values."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.rand(in_dim, out_dim)
return self.generate_end_hook(weights)
class EyeLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with the largest possible identity matrix.""" """Initialize a matrix with the largest possible identity matrix."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim) weights = torch.zeros(in_dim, out_dim)
I = torch.eye(min(in_dim, out_dim)) I = torch.eye(min(in_dim, out_dim))
@ -485,7 +438,6 @@ class EyeLinearTransformInitializer(AbstractLinearTransformInitializer):
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
"""Abstract class for all data-aware linear transform initializers.""" """Abstract class for all data-aware linear transform initializers."""
def __init__(self, def __init__(self,
data: torch.Tensor, data: torch.Tensor,
noise: float = 0.0, noise: float = 0.0,
@ -506,19 +458,11 @@ class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
class PCALinearTransformInitializer(AbstractDataAwareLTInitializer): class PCALinearTransformInitializer(AbstractDataAwareLTInitializer):
"""Initialize a matrix with Eigenvectors from the data.""" """Initialize a matrix with Eigenvectors from the data."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
_, _, weights = torch.pca_lowrank(self.data, q=out_dim) _, _, weights = torch.pca_lowrank(self.data, q=out_dim)
return self.generate_end_hook(weights) return self.generate_end_hook(weights)
class LiteralLinearTransformInitializer(AbstractDataAwareLTInitializer):
"""'Generate' the provided weights."""
def generate(self, in_dim: int, out_dim: int):
return self.generate_end_hook(self.data)
# Aliases - Components # Aliases - Components
CACI = ClassAwareCompInitializer CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer DACI = DataAwareCompInitializer
@ -547,9 +491,7 @@ RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer ZRI = ZerosReasoningsInitializer
# Aliases - Transforms # Aliases - Transforms
ELTI = Eye = EyeLinearTransformInitializer Eye = EyeTransformInitializer
OLTI = OnesLinearTransformInitializer OLTI = OnesLinearTransformInitializer
RLTI = RandomLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer ZLTI = ZerosLinearTransformInitializer
PCALTI = PCALinearTransformInitializer PCALTI = PCALinearTransformInitializer
LLTI = LiteralLinearTransformInitializer

View File

@ -2,7 +2,7 @@
import torch import torch
from prototorch.nn.activations import get_activation from ..nn.activations import get_activation
# Helpers # Helpers
@ -106,31 +106,19 @@ def margin_loss(y_pred, y_true, margin=0.3):
class GLVQLoss(torch.nn.Module): class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
def __init__(self,
margin=0.0,
transfer_fn="identity",
beta=10,
add_dp=False,
**kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.margin = margin self.margin = margin
self.transfer_fn = get_activation(transfer_fn) self.transfer_fn = get_activation(transfer_fn)
self.beta = torch.tensor(beta) self.beta = torch.tensor(beta)
self.add_dp = add_dp
def forward(self, outputs, targets, plabels): def forward(self, outputs, targets, plabels):
# mu = glvq_loss(outputs, targets, plabels) mu = glvq_loss(outputs, targets, prototype_labels=plabels)
dp, dm = _get_dp_dm(outputs, targets, plabels)
mu = (dp - dm) / (dp + dm)
if self.add_dp:
mu = mu + dp
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta) batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
return batch_loss.sum() return batch_loss.sum()
class MarginLoss(torch.nn.modules.loss._Loss): class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self, def __init__(self,
margin=0.3, margin=0.3,
size_average=None, size_average=None,
@ -144,7 +132,6 @@ class MarginLoss(torch.nn.modules.loss._Loss):
class NeuralGasEnergy(torch.nn.Module): class NeuralGasEnergy(torch.nn.Module):
def __init__(self, lm, **kwargs): def __init__(self, lm, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.lm = lm self.lm = lm
@ -165,7 +152,6 @@ class NeuralGasEnergy(torch.nn.Module):
class GrowingNeuralGasEnergy(NeuralGasEnergy): class GrowingNeuralGasEnergy(NeuralGasEnergy):
def __init__(self, topology_layer, **kwargs): def __init__(self, topology_layer, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.topology_layer = topology_layer self.topology_layer = topology_layer

View File

@ -82,27 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
class StratifiedSumPooling(torch.nn.Module): class StratifiedSumPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_sum_pooling` function.""" """Thin wrapper over the `stratified_sum_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_sum_pooling(values, labels) return stratified_sum_pooling(values, labels)
class StratifiedProdPooling(torch.nn.Module): class StratifiedProdPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_prod_pooling` function.""" """Thin wrapper over the `stratified_prod_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_prod_pooling(values, labels) return stratified_prod_pooling(values, labels)
class StratifiedMinPooling(torch.nn.Module): class StratifiedMinPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_min_pooling` function.""" """Thin wrapper over the `stratified_min_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_min_pooling(values, labels) return stratified_min_pooling(values, labels)
class StratifiedMaxPooling(torch.nn.Module): class StratifiedMaxPooling(torch.nn.Module):
"""Thin wrapper over the `stratified_max_pooling` function.""" """Thin wrapper over the `stratified_max_pooling` function."""
def forward(self, values, labels): # pylint: disable=no-self-use def forward(self, values, labels): # pylint: disable=no-self-use
return stratified_max_pooling(values, labels) return stratified_max_pooling(values, labels)

View File

@ -21,7 +21,7 @@ def cosine_similarity(x, y):
Expected dimension of x is 2. Expected dimension of x is 2.
Expected dimension of y is 2. Expected dimension of y is 2.
""" """
x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
norm_x = x.pow(2).sum(1).sqrt() norm_x = x.pow(2).sum(1).sqrt()
norm_y = y.pow(2).sum(1).sqrt() norm_y = y.pow(2).sum(1).sqrt()
norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T norm_mat = norm_x.unsqueeze(-1) @ norm_y.unsqueeze(-1).T

View File

@ -5,18 +5,17 @@ from torch.nn.parameter import Parameter
from .initializers import ( from .initializers import (
AbstractLinearTransformInitializer, AbstractLinearTransformInitializer,
EyeLinearTransformInitializer, EyeTransformInitializer,
) )
class LinearTransform(torch.nn.Module): class LinearTransform(torch.nn.Module):
def __init__( def __init__(
self, self,
in_dim: int, in_dim: int,
out_dim: int, out_dim: int,
initializer: initializer:
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()): AbstractLinearTransformInitializer = EyeTransformInitializer()):
super().__init__() super().__init__()
self.set_weights(in_dim, out_dim, initializer) self.set_weights(in_dim, out_dim, initializer)
@ -32,15 +31,12 @@ class LinearTransform(torch.nn.Module):
in_dim: int, in_dim: int,
out_dim: int, out_dim: int,
initializer: initializer:
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()): AbstractLinearTransformInitializer = EyeTransformInitializer()):
weights = initializer.generate(in_dim, out_dim) weights = initializer.generate(in_dim, out_dim)
self._register_weights(weights) self._register_weights(weights)
def forward(self, x): def forward(self, x):
return x @ self._weights return x @ self.weights
def extra_repr(self):
return f"weights: (shape: {tuple(self._weights.shape)})"
# Aliases # Aliases

View File

@ -20,7 +20,7 @@ class Dataset(torch.utils.data.Dataset):
_repr_indent = 2 _repr_indent = 2
def __init__(self, root): def __init__(self, root):
if isinstance(root, str): if isinstance(root, torch._six.string_classes):
root = os.path.expanduser(root) root = os.path.expanduser(root)
self.root = root self.root = root
@ -93,7 +93,6 @@ class ProtoDataset(Dataset):
class NumpyDataset(torch.utils.data.TensorDataset): class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays.""" """Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, data, targets): def __init__(self, data, targets):
self.data = torch.Tensor(data) self.data = torch.Tensor(data)
self.targets = torch.LongTensor(targets) self.targets = torch.LongTensor(targets)
@ -103,7 +102,6 @@ class NumpyDataset(torch.utils.data.TensorDataset):
class CSVDataset(NumpyDataset): class CSVDataset(NumpyDataset):
"""Create a Dataset from a CSV file.""" """Create a Dataset from a CSV file."""
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0): def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
raw = np.genfromtxt( raw = np.genfromtxt(
filepath, filepath,

View File

@ -5,18 +5,11 @@ URL:
""" """
from __future__ import annotations
import warnings import warnings
from typing import Sequence from typing import Sequence, Union
from sklearn.datasets import ( from sklearn.datasets import (load_iris, make_blobs, make_circles,
load_iris, make_classification, make_moons)
make_blobs,
make_circles,
make_classification,
make_moons,
)
from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.abstract import NumpyDataset
@ -42,10 +35,9 @@ class Iris(NumpyDataset):
:param dims: select a subset of dimensions :param dims: select a subset of dimensions
""" """
def __init__(self, dims: Sequence[int] = None):
def __init__(self, dims: Sequence[int] | None = None):
x, y = load_iris(return_X_y=True) x, y = load_iris(return_X_y=True)
if dims is not None: if dims:
x = x[:, dims] x = x[:, dims]
super().__init__(x, y) super().__init__(x, y)
@ -57,20 +49,15 @@ class Blobs(NumpyDataset):
https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators. https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.
""" """
def __init__(self,
def __init__( num_samples: int = 300,
self, num_features: int = 2,
num_samples: int = 300, seed: Union[None, int] = 0):
num_features: int = 2, x, y = make_blobs(num_samples,
seed: None | int = 0, num_features,
): centers=None,
x, y = make_blobs( random_state=seed,
num_samples, shuffle=False)
num_features,
centers=None,
random_state=seed,
shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)
@ -82,34 +69,29 @@ class Random(NumpyDataset):
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy. Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
""" """
def __init__(self,
def __init__( num_samples: int = 300,
self, num_features: int = 2,
num_samples: int = 300, num_classes: int = 2,
num_features: int = 2, num_clusters: int = 2,
num_classes: int = 2, num_informative: Union[None, int] = None,
num_clusters: int = 2, separation: float = 1.0,
num_informative: None | int = None, seed: Union[None, int] = 0):
separation: float = 1.0,
seed: None | int = 0,
):
if not num_informative: if not num_informative:
import math import math
num_informative = math.ceil(math.log2(num_classes * num_clusters)) num_informative = math.ceil(math.log2(num_classes * num_clusters))
if num_features < num_informative: if num_features < num_informative:
warnings.warn("Generating more features than requested.") warnings.warn("Generating more features than requested.")
num_features = num_informative num_features = num_informative
x, y = make_classification( x, y = make_classification(num_samples,
num_samples, num_features,
num_features, n_informative=num_informative,
n_informative=num_informative, n_redundant=0,
n_redundant=0, n_classes=num_classes,
n_classes=num_classes, n_clusters_per_class=num_clusters,
n_clusters_per_class=num_clusters, class_sep=separation,
class_sep=separation, random_state=seed,
random_state=seed, shuffle=False)
shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)
@ -122,21 +104,16 @@ class Circles(NumpyDataset):
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html
""" """
def __init__(self,
def __init__( num_samples: int = 300,
self, noise: float = 0.3,
num_samples: int = 300, factor: float = 0.8,
noise: float = 0.3, seed: Union[None, int] = 0):
factor: float = 0.8, x, y = make_circles(num_samples,
seed: None | int = 0, noise=noise,
): factor=factor,
x, y = make_circles( random_state=seed,
num_samples, shuffle=False)
noise=noise,
factor=factor,
random_state=seed,
shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)
@ -149,17 +126,12 @@ class Moons(NumpyDataset):
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html
""" """
def __init__(self,
def __init__( num_samples: int = 300,
self, noise: float = 0.3,
num_samples: int = 300, seed: Union[None, int] = 0):
noise: float = 0.3, x, y = make_moons(num_samples,
seed: None | int = 0, noise=noise,
): random_state=seed,
x, y = make_moons( shuffle=False)
num_samples,
noise=noise,
random_state=seed,
shuffle=False,
)
super().__init__(x, y) super().__init__(x, y)

View File

@ -9,7 +9,6 @@ def make_spiral(num_samples=500, noise=0.3):
For use in Prototorch use `prototorch.datasets.Spiral` instead. For use in Prototorch use `prototorch.datasets.Spiral` instead.
""" """
def get_samples(n, delta_t): def get_samples(n, delta_t):
points = [] points = []
for i in range(n): for i in range(n):
@ -53,7 +52,6 @@ class Spiral(torch.utils.data.TensorDataset):
:param num_samples: number of random samples :param num_samples: number of random samples
:param noise: noise added to the spirals :param noise: noise added to the spirals
""" """
def __init__(self, num_samples: int = 500, noise: float = 0.3): def __init__(self, num_samples: int = 500, noise: float = 0.3):
x, y = make_spiral(num_samples, noise) x, y = make_spiral(num_samples, noise)
super().__init__(torch.Tensor(x), torch.LongTensor(y)) super().__init__(torch.Tensor(x), torch.LongTensor(y))

View File

@ -36,7 +36,6 @@ Description:
are determined by analytic chemistry. are determined by analytic chemistry.
""" """
import logging
import os import os
import numpy as np import numpy as np
@ -82,11 +81,13 @@ class Tecator(ProtoDataset):
if self._check_exists(): if self._check_exists():
return return
logging.debug("Making directories...") if self.verbose:
print("Making directories...")
os.makedirs(self.raw_folder, exist_ok=True) os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True) os.makedirs(self.processed_folder, exist_ok=True)
logging.debug("Downloading...") if self.verbose:
print("Downloading...")
for fileid, md5 in self._resources: for fileid, md5 in self._resources:
filename = "tecator.npz" filename = "tecator.npz"
download_file_from_google_drive(fileid, download_file_from_google_drive(fileid,
@ -94,7 +95,8 @@ class Tecator(ProtoDataset):
filename=filename, filename=filename,
md5=md5) md5=md5)
logging.debug("Processing...") if self.verbose:
print("Processing...")
with np.load(os.path.join(self.raw_folder, "tecator.npz"), with np.load(os.path.join(self.raw_folder, "tecator.npz"),
allow_pickle=False) as f: allow_pickle=False) as f:
x_train, y_train = f["x_train"], f["y_train"] x_train, y_train = f["x_train"], f["y_train"]
@ -115,4 +117,5 @@ class Tecator(ProtoDataset):
"wb") as f: "wb") as f:
torch.save(test_set, f) torch.save(test_set, f)
logging.debug("Done!") if self.verbose:
print("Done!")

View File

@ -13,7 +13,6 @@ def make_xor(num_samples=500):
class XOR(torch.utils.data.TensorDataset): class XOR(torch.utils.data.TensorDataset):
"""Exclusive-or (XOR) dataset for binary classification.""" """Exclusive-or (XOR) dataset for binary classification."""
def __init__(self, num_samples: int = 500): def __init__(self, num_samples: int = 500):
x, y = make_xor(num_samples) x, y = make_xor(num_samples)
super().__init__(x, y) super().__init__(x, y)

View File

@ -4,7 +4,6 @@ import torch
class LambdaLayer(torch.nn.Module): class LambdaLayer(torch.nn.Module):
def __init__(self, fn, name=None): def __init__(self, fn, name=None):
super().__init__() super().__init__()
self.fn = fn self.fn = fn
@ -18,7 +17,6 @@ class LambdaLayer(torch.nn.Module):
class LossLayer(torch.nn.modules.loss._Loss): class LossLayer(torch.nn.modules.loss._Loss):
def __init__(self, def __init__(self,
fn, fn,
name=None, name=None,

View File

@ -1,11 +1,6 @@
"""ProtoTorch utils module""" """ProtoFlow utils module"""
from .colors import ( from .colors import hex_to_rgb, rgb_to_hex
get_colors,
get_legend_handles,
hex_to_rgb,
rgb_to_hex,
)
from .utils import ( from .utils import (
mesh2d, mesh2d,
parse_data_arg, parse_data_arg,

View File

@ -1,13 +1,4 @@
"""ProtoTorch color utilities""" """ProtoFlow color utilities"""
import matplotlib.lines as mlines
import torch
from matplotlib import cm
from matplotlib.colors import (
Normalize,
to_hex,
to_rgb,
)
def hex_to_rgb(hex_values): def hex_to_rgb(hex_values):
@ -22,39 +13,3 @@ def rgb_to_hex(rgb_values):
for v in rgb_values: for v in rgb_values:
c = "%02x%02x%02x" % tuple(v) c = "%02x%02x%02x" % tuple(v)
yield c yield c
def get_colors(vmax, vmin=0, cmap="viridis"):
cmap = cm.get_cmap(cmap)
colornorm = Normalize(vmin=vmin, vmax=vmax)
colors = dict()
for c in range(vmin, vmax + 1):
colors[c] = to_hex(cmap(colornorm(c)))
return colors
def get_legend_handles(colors, labels, marker="dots", zero_indexed=False):
handles = list()
for color, label in zip(colors.values(), labels):
if marker == "dots":
handle = mlines.Line2D(
xdata=[],
ydata=[],
label=label,
color="white",
markerfacecolor=color,
marker="o",
markersize=10,
markeredgecolor="k",
)
else:
handle = mlines.Line2D(
xdata=[],
ydata=[],
label=label,
color=color,
marker="",
markersize=15,
)
handles.append(handle)
return handles

View File

@ -1,11 +1,10 @@
"""ProtoTorch utilities""" """ProtoFlow utilities"""
import warnings import warnings
from typing import ( from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Optional,
Union, Union,
) )
@ -14,32 +13,6 @@ import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
def generate_mesh(
minima: torch.TensorType,
maxima: torch.TensorType,
border: float = 1.0,
resolution: int = 100,
device: Optional[torch.device] = None,
):
# Apply Border
ptp = maxima - minima
shift = border * ptp
minima -= shift
maxima += shift
# Generate Mesh
minima = minima.to(device).unsqueeze(1)
maxima = maxima.to(device).unsqueeze(1)
factors = torch.linspace(0, 1, resolution, device=device)
marginals = factors * maxima + ((1 - factors) * minima)
single_dimensions = torch.meshgrid(*marginals)
mesh_input = torch.stack([dim.ravel() for dim in single_dimensions], dim=1)
return mesh_input, single_dimensions
def mesh2d(x=None, border: float = 1.0, resolution: int = 100): def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
if x is not None: if x is not None:
x_shift = border * np.ptp(x[:, 0]) x_shift = border * np.ptp(x[:, 0])
@ -56,15 +29,14 @@ def mesh2d(x=None, border: float = 1.0, resolution: int = 100):
def distribution_from_list(list_dist: List[int], def distribution_from_list(list_dist: List[int],
clabels: Optional[Iterable[int]] = None): clabels: Iterable[int] = None):
clabels = clabels or list(range(len(list_dist))) clabels = clabels or list(range(len(list_dist)))
distribution = dict(zip(clabels, list_dist)) distribution = dict(zip(clabels, list_dist))
return distribution return distribution
def parse_distribution( def parse_distribution(user_distribution,
user_distribution, clabels: Iterable[int] = None) -> Dict[int, int]:
clabels: Optional[Iterable[int]] = None) -> Dict[int, int]:
"""Parse user-provided distribution. """Parse user-provided distribution.
Return a dictionary with integer keys that represent the class labels and Return a dictionary with integer keys that represent the class labels and

View File

@ -1,9 +1,8 @@
[pylint] [pylint]
disable = disable =
too-many-arguments, too-many-arguments,
too-few-public-methods, too-few-public-methods,
fixme, fixme,
[pycodestyle] [pycodestyle]
max-line-length = 79 max-line-length = 79
@ -13,4 +12,4 @@ multi_line_output = 3
include_trailing_comma = True include_trailing_comma = True
force_grid_wrap = 3 force_grid_wrap = 3
use_parentheses = True use_parentheses = True
line_length = 79 line_length = 79

View File

@ -15,22 +15,21 @@ from setuptools import find_packages, setup
PROJECT_URL = "https://github.com/si-cim/prototorch" PROJECT_URL = "https://github.com/si-cim/prototorch"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git" DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
with open("README.md", encoding="utf-8") as fh: with open("README.md", "r") as fh:
long_description = fh.read() long_description = fh.read()
INSTALL_REQUIRES = [ INSTALL_REQUIRES = [
"torch>=2.0.0", "torch>=1.3.1",
"torchvision", "torchvision>=0.7.1",
"numpy", "numpy>=1.9.1",
"scikit-learn", "sklearn",
"matplotlib",
] ]
DATASETS = [ DATASETS = [
"requests", "requests",
"tqdm", "tqdm",
] ]
DEV = [ DEV = [
"bump2version", "bumpversion",
"pre-commit", "pre-commit",
] ]
DOCS = [ DOCS = [
@ -41,17 +40,18 @@ DOCS = [
"sphinx-autodoc-typehints", "sphinx-autodoc-typehints",
] ]
EXAMPLES = [ EXAMPLES = [
"matplotlib",
"torchinfo", "torchinfo",
] ]
TESTS = [ TESTS = [
"flake8", "pytest-cov",
"pytest", "pytest",
] ]
ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
name="prototorch", name="prototorch",
version="0.7.6", version="0.7.1",
description="Highly extensible, GPU-supported " description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox " "Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.", "built using PyTorch and its nn API.",
@ -62,7 +62,7 @@ setup(
url=PROJECT_URL, url=PROJECT_URL,
download_url=DOWNLOAD_URL, download_url=DOWNLOAD_URL,
license="MIT", license="MIT",
python_requires=">=3.8", python_requires=">=3.6",
install_requires=INSTALL_REQUIRES, install_requires=INSTALL_REQUIRES,
extras_require={ extras_require={
"datasets": DATASETS, "datasets": DATASETS,
@ -85,10 +85,10 @@ setup(
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
], ],
packages=find_packages(), packages=find_packages(),
zip_safe=False, zip_safe=False,

View File

@ -245,20 +245,20 @@ def test_random_reasonings_init_channels_not_first():
# Transform initializers # Transform initializers
def test_eye_transform_init_square(): def test_eye_transform_init_square():
t = pt.initializers.EyeLinearTransformInitializer() t = pt.initializers.EyeTransformInitializer()
I = t.generate(3, 3) I = t.generate(3, 3)
assert torch.allclose(I, torch.eye(3)) assert torch.allclose(I, torch.eye(3))
def test_eye_transform_init_narrow(): def test_eye_transform_init_narrow():
t = pt.initializers.EyeLinearTransformInitializer() t = pt.initializers.EyeTransformInitializer()
actual = t.generate(3, 2) actual = t.generate(3, 2)
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]]) desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
assert torch.allclose(actual, desired) assert torch.allclose(actual, desired)
def test_eye_transform_init_wide(): def test_eye_transform_init_wide():
t = pt.initializers.EyeLinearTransformInitializer() t = pt.initializers.EyeTransformInitializer()
actual = t.generate(2, 3) actual = t.generate(2, 3)
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]]) desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
assert torch.allclose(actual, desired) assert torch.allclose(actual, desired)
@ -404,7 +404,6 @@ def test_glvq_loss_one_hot_unequal():
# Activations # Activations
class TestActivations(unittest.TestCase): class TestActivations(unittest.TestCase):
def setUp(self): def setUp(self):
self.flist = ["identity", "sigmoid_beta", "swish_beta"] self.flist = ["identity", "sigmoid_beta", "swish_beta"]
self.x = torch.randn(1024, 1) self.x = torch.randn(1024, 1)
@ -419,7 +418,6 @@ class TestActivations(unittest.TestCase):
self.assertTrue(iscallable) self.assertTrue(iscallable)
def test_callable_deserialization(self): def test_callable_deserialization(self):
def dummy(x, **kwargs): def dummy(x, **kwargs):
return x return x
@ -464,7 +462,6 @@ class TestActivations(unittest.TestCase):
# Competitions # Competitions
class TestCompetitions(unittest.TestCase): class TestCompetitions(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
@ -518,7 +515,6 @@ class TestCompetitions(unittest.TestCase):
# Pooling # Pooling
class TestPooling(unittest.TestCase): class TestPooling(unittest.TestCase):
def setUp(self): def setUp(self):
pass pass
@ -619,7 +615,6 @@ class TestPooling(unittest.TestCase):
# Distances # Distances
class TestDistances(unittest.TestCase): class TestDistances(unittest.TestCase):
def setUp(self): def setUp(self):
self.nx, self.mx = 32, 2048 self.nx, self.mx = 32, 2048
self.ny, self.my = 8, 2048 self.ny, self.my = 8, 2048

View File

@ -1,6 +1,7 @@
"""ProtoTorch datasets test suite""" """ProtoTorch datasets test suite"""
import os import os
import shutil
import unittest import unittest
import numpy as np import numpy as np
@ -11,7 +12,6 @@ from prototorch.datasets.abstract import Dataset, ProtoDataset
class TestAbstract(unittest.TestCase): class TestAbstract(unittest.TestCase):
def setUp(self): def setUp(self):
self.ds = Dataset("./artifacts") self.ds = Dataset("./artifacts")
@ -28,7 +28,6 @@ class TestAbstract(unittest.TestCase):
class TestProtoDataset(unittest.TestCase): class TestProtoDataset(unittest.TestCase):
def test_download(self): def test_download(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
_ = ProtoDataset("./artifacts", download=True) _ = ProtoDataset("./artifacts", download=True)
@ -39,7 +38,6 @@ class TestProtoDataset(unittest.TestCase):
class TestNumpyDataset(unittest.TestCase): class TestNumpyDataset(unittest.TestCase):
def test_list_init(self): def test_list_init(self):
ds = pt.datasets.NumpyDataset([1], [1]) ds = pt.datasets.NumpyDataset([1], [1])
self.assertEqual(len(ds), 1) self.assertEqual(len(ds), 1)
@ -52,7 +50,6 @@ class TestNumpyDataset(unittest.TestCase):
class TestCSVDataset(unittest.TestCase): class TestCSVDataset(unittest.TestCase):
def setUp(self): def setUp(self):
data = np.random.rand(100, 4) data = np.random.rand(100, 4)
targets = np.random.randint(2, size=(100, 1)) targets = np.random.randint(2, size=(100, 1))
@ -70,14 +67,12 @@ class TestCSVDataset(unittest.TestCase):
class TestSpiral(unittest.TestCase): class TestSpiral(unittest.TestCase):
def test_init(self): def test_init(self):
ds = pt.datasets.Spiral(num_samples=10) ds = pt.datasets.Spiral(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestIris(unittest.TestCase): class TestIris(unittest.TestCase):
def setUp(self): def setUp(self):
self.ds = pt.datasets.Iris() self.ds = pt.datasets.Iris()
@ -93,28 +88,24 @@ class TestIris(unittest.TestCase):
class TestBlobs(unittest.TestCase): class TestBlobs(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Blobs(num_samples=10) ds = pt.datasets.Blobs(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestRandom(unittest.TestCase): class TestRandom(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Random(num_samples=10) ds = pt.datasets.Random(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestCircles(unittest.TestCase): class TestCircles(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Circles(num_samples=10) ds = pt.datasets.Circles(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)
class TestMoons(unittest.TestCase): class TestMoons(unittest.TestCase):
def test_size(self): def test_size(self):
ds = pt.datasets.Moons(num_samples=10) ds = pt.datasets.Moons(num_samples=10)
self.assertEqual(len(ds), 10) self.assertEqual(len(ds), 10)