Improve dataset documentation.

This commit is contained in:
Alexander Engelsberger 2021-05-18 18:54:43 +02:00
parent b935e9caf3
commit b2e1df7308
8 changed files with 81 additions and 8 deletions

View File

@ -5,9 +5,20 @@ ProtoTorch API Reference
Datasets Datasets
-------------------------------------- --------------------------------------
Common Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: prototorch.datasets .. automodule:: prototorch.datasets
:members: :members:
:undoc-members:
Abstract Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Abstract Datasets are used to build your own datasets.
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
:members:
Functions Functions
-------------------------------------- --------------------------------------

View File

@ -46,6 +46,7 @@ extensions = [
"sphinx.ext.viewcode", "sphinx.ext.viewcode",
"sphinx_rtd_theme", "sphinx_rtd_theme",
"sphinxcontrib.katex", "sphinxcontrib.katex",
'sphinx_autodoc_typehints',
] ]
# katex_prerender = True # katex_prerender = True

View File

@ -4,3 +4,5 @@ from .abstract import NumpyDataset
from .iris import Iris from .iris import Iris
from .spiral import Spiral from .spiral import Spiral
from .tecator import Tecator from .tecator import Tecator
__all__ = ['Iris', 'Spiral', 'Tecator']

View File

@ -5,12 +5,35 @@ URL:
""" """
from typing import Sequence
from prototorch.datasets.abstract import NumpyDataset from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
class Iris(NumpyDataset): class Iris(NumpyDataset):
def __init__(self, dims=None): """
Iris dataset, introduced by Ronald Fisher in 1936.
The dataset contains four measurements from flowers of three species of iris.
.. list-table:: Iris
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 4
- 3
- 150
- 0
- 0
:param dims: select a subset of dimensions
"""
def __init__(self, dims: Sequence[int] = None):
x, y = load_iris(return_X_y=True) x, y = load_iris(return_X_y=True)
if dims: if dims:
x = x[:, dims] x = x[:, dims]

View File

@ -27,7 +27,27 @@ def make_spiral(n_samples=500, noise=0.3):
class Spiral(torch.utils.data.TensorDataset): class Spiral(torch.utils.data.TensorDataset):
"""Spiral dataset for binary classification.""" """Spiral dataset for binary classification.
def __init__(self, n_samples=500, noise=0.3):
This dataset consists of two spirals, one for each of the two classes.
.. list-table:: Spiral
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 2
- 2
- n_samples
- 0
- 0
:param n_samples: number of random samples
:param noise: noise added to the spirals
"""
def __init__(self, n_samples: int = 500, noise: float = 0.3):
x, y = make_spiral(n_samples, noise) x, y = make_spiral(n_samples, noise)
super().__init__(torch.Tensor(x), torch.LongTensor(y)) super().__init__(torch.Tensor(x), torch.LongTensor(y))

View File

@ -40,15 +40,29 @@ import os
import numpy as np import numpy as np
import torch import torch
from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset from prototorch.datasets.abstract import ProtoDataset
from torchvision.datasets.utils import download_file_from_google_drive
class Tecator(ProtoDataset): class Tecator(ProtoDataset):
""" """
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ `Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.
for classification.
The dataset contains near-infrared absorbance spectra of meat samples, each measured at 100 wavelengths.
.. list-table:: Tecator
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 100
- 2
- 129
- 43
- 43
""" """
_resources = [ _resources = [

View File

@ -55,4 +55,5 @@ def lvq21_loss(distances, target_labels, prototype_labels):
""" """
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp - dm mu = dp - dm
return mu return mu

View File

@ -32,6 +32,7 @@ DOCS = [
"sphinx", "sphinx",
"sphinx_rtd_theme", "sphinx_rtd_theme",
"sphinxcontrib-katex", "sphinxcontrib-katex",
"sphinx-autodoc-typehints",
] ]
EXAMPLES = [ EXAMPLES = [
"sklearn", "sklearn",