Improve dataset documentation.

This commit is contained in:
Alexander Engelsberger 2021-05-18 18:54:43 +02:00
parent b935e9caf3
commit b2e1df7308
8 changed files with 81 additions and 8 deletions

View File

@ -5,9 +5,20 @@ ProtoTorch API Reference
Datasets
--------------------------------------
Common Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: prototorch.datasets
:members:
:undoc-members:
Abstract Datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Abstract Datasets are used to build your own datasets.
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
:members:
Functions
--------------------------------------

View File

@ -46,6 +46,7 @@ extensions = [
"sphinx.ext.viewcode",
"sphinx_rtd_theme",
"sphinxcontrib.katex",
'sphinx_autodoc_typehints',
]
# katex_prerender = True

View File

@ -4,3 +4,5 @@ from .abstract import NumpyDataset
from .iris import Iris
from .spiral import Spiral
from .tecator import Tecator
__all__ = ['Iris', 'Spiral', 'Tecator']

View File

@ -5,12 +5,35 @@ URL:
"""
from typing import Sequence
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris
class Iris(NumpyDataset):
def __init__(self, dims=None):
"""
Iris Dataset by Ronald Fisher introduced in 1936.
The dataset contains four measurements from flowers of three species of iris.
.. list-table:: Iris
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 4
- 3
- 150
- 0
- 0
:param dims: select a subset of dimensions
"""
def __init__(self, dims: Sequence[int] = None):
x, y = load_iris(return_X_y=True)
if dims:
x = x[:, dims]

View File

@ -27,7 +27,27 @@ def make_spiral(n_samples=500, noise=0.3):
class Spiral(torch.utils.data.TensorDataset):
"""Spiral dataset for binary classification."""
def __init__(self, n_samples=500, noise=0.3):
"""Spiral dataset for binary classification.
This dataset consists of two spirals of two different classes.
.. list-table:: Spiral
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 2
- 2
- n_samples
- 0
- 0
:param n_samples: number of random samples
:param noise: noise added to the spirals
"""
def __init__(self, n_samples: int = 500, noise: float = 0.3):
x, y = make_spiral(n_samples, noise)
super().__init__(torch.Tensor(x), torch.LongTensor(y))

View File

@ -40,15 +40,29 @@ import os
import numpy as np
import torch
from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
from torchvision.datasets.utils import download_file_from_google_drive
class Tecator(ProtoDataset):
"""
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
for classification.
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.
The dataset contains near-infrared absorbance spectra of meat samples, measured at 100 wavelengths.
.. list-table:: Tecator
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 100
- 2
- 129
- 43
- 43
"""
_resources = [

View File

@ -55,4 +55,5 @@ def lvq21_loss(distances, target_labels, prototype_labels):
"""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp - dm
return mu

View File

@ -32,6 +32,7 @@ DOCS = [
"sphinx",
"sphinx_rtd_theme",
"sphinxcontrib-katex",
"sphinx-autodoc-typehints",
]
EXAMPLES = [
"sklearn",