Improve dataset documentation.
This commit is contained in:
parent
b935e9caf3
commit
b2e1df7308
@ -5,9 +5,20 @@ ProtoTorch API Reference
|
||||
|
||||
Datasets
|
||||
--------------------------------------
|
||||
|
||||
Common Datasets
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
.. automodule:: prototorch.datasets
|
||||
:members:
|
||||
:undoc-members:
|
||||
|
||||
|
||||
Abstract Datasets
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Abstract Datasets are used to build your own datasets.
|
||||
|
||||
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
|
||||
:members:
|
||||
|
||||
Functions
|
||||
--------------------------------------
|
||||
|
@ -46,6 +46,7 @@ extensions = [
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx_rtd_theme",
|
||||
"sphinxcontrib.katex",
|
||||
'sphinx_autodoc_typehints',
|
||||
]
|
||||
|
||||
# katex_prerender = True
|
||||
|
@ -4,3 +4,5 @@ from .abstract import NumpyDataset
|
||||
from .iris import Iris
|
||||
from .spiral import Spiral
|
||||
from .tecator import Tecator
|
||||
|
||||
__all__ = ['Iris', 'Spiral', 'Tecator']
|
||||
|
@ -5,12 +5,35 @@ URL:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
from prototorch.datasets.abstract import NumpyDataset
|
||||
from sklearn.datasets import load_iris
|
||||
|
||||
|
||||
class Iris(NumpyDataset):
|
||||
def __init__(self, dims=None):
|
||||
"""
|
||||
Iris dataset, introduced by Ronald Fisher in 1936.
|
||||
|
||||
The dataset contains four measurements from flowers of three species of iris.
|
||||
|
||||
.. list-table:: Iris
|
||||
:header-rows: 1
|
||||
|
||||
* - dimensions
|
||||
- classes
|
||||
- training size
|
||||
- validation size
|
||||
- test size
|
||||
* - 4
|
||||
- 3
|
||||
- 150
|
||||
- 0
|
||||
- 0
|
||||
|
||||
:param dims: select a subset of dimensions
|
||||
"""
|
||||
def __init__(self, dims: Sequence[int] = None):
|
||||
x, y = load_iris(return_X_y=True)
|
||||
if dims:
|
||||
x = x[:, dims]
|
||||
|
@ -27,7 +27,27 @@ def make_spiral(n_samples=500, noise=0.3):
|
||||
|
||||
|
||||
class Spiral(torch.utils.data.TensorDataset):
|
||||
"""Spiral dataset for binary classification."""
|
||||
def __init__(self, n_samples=500, noise=0.3):
|
||||
"""Spiral dataset for binary classification.
|
||||
|
||||
This dataset consists of two spirals, each belonging to a different class.
|
||||
|
||||
.. list-table:: Spiral
|
||||
:header-rows: 1
|
||||
|
||||
* - dimensions
|
||||
- classes
|
||||
- training size
|
||||
- validation size
|
||||
- test size
|
||||
* - 2
|
||||
- 2
|
||||
- n_samples
|
||||
- 0
|
||||
- 0
|
||||
|
||||
:param n_samples: number of random samples
|
||||
:param noise: noise added to the spirals
|
||||
"""
|
||||
def __init__(self, n_samples: int = 500, noise: float = 0.3):
|
||||
x, y = make_spiral(n_samples, noise)
|
||||
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
||||
|
@ -40,15 +40,29 @@ import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.datasets.utils import download_file_from_google_drive
|
||||
|
||||
from prototorch.datasets.abstract import ProtoDataset
|
||||
from torchvision.datasets.utils import download_file_from_google_drive
|
||||
|
||||
|
||||
class Tecator(ProtoDataset):
|
||||
"""
|
||||
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
|
||||
for classification.
|
||||
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.
|
||||
|
||||
The dataset contains absorbance spectra of meat samples, measured at 100 wavelengths.
|
||||
|
||||
.. list-table:: Tecator
|
||||
:header-rows: 1
|
||||
|
||||
* - dimensions
|
||||
- classes
|
||||
- training size
|
||||
- validation size
|
||||
- test size
|
||||
* - 100
|
||||
- 2
|
||||
- 129
|
||||
- 43
|
||||
- 43
|
||||
"""
|
||||
|
||||
_resources = [
|
||||
|
@ -55,4 +55,5 @@ def lvq21_loss(distances, target_labels, prototype_labels):
|
||||
"""
|
||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||
mu = dp - dm
|
||||
|
||||
return mu
|
||||
|
Loading…
Reference in New Issue
Block a user