Improve dataset documentation.
This commit is contained in:
parent
b935e9caf3
commit
b2e1df7308
@ -5,9 +5,20 @@ ProtoTorch API Reference
|
|||||||
|
|
||||||
Datasets
|
Datasets
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
|
|
||||||
|
Common Datasets
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
.. automodule:: prototorch.datasets
|
.. automodule:: prototorch.datasets
|
||||||
:members:
|
:members:
|
||||||
:undoc-members:
|
|
||||||
|
|
||||||
|
Abstract Datasets
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Abstract Datasets are used to build your own datasets.
|
||||||
|
|
||||||
|
.. autoclass:: prototorch.datasets.abstract.NumpyDataset
|
||||||
|
:members:
|
||||||
|
|
||||||
Functions
|
Functions
|
||||||
--------------------------------------
|
--------------------------------------
|
||||||
|
@ -46,6 +46,7 @@ extensions = [
|
|||||||
"sphinx.ext.viewcode",
|
"sphinx.ext.viewcode",
|
||||||
"sphinx_rtd_theme",
|
"sphinx_rtd_theme",
|
||||||
"sphinxcontrib.katex",
|
"sphinxcontrib.katex",
|
||||||
|
'sphinx_autodoc_typehints',
|
||||||
]
|
]
|
||||||
|
|
||||||
# katex_prerender = True
|
# katex_prerender = True
|
||||||
|
@ -4,3 +4,5 @@ from .abstract import NumpyDataset
|
|||||||
from .iris import Iris
|
from .iris import Iris
|
||||||
from .spiral import Spiral
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
|
||||||
|
__all__ = ['Iris', 'Spiral', 'Tecator']
|
||||||
|
@ -5,12 +5,35 @@ URL:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
from prototorch.datasets.abstract import NumpyDataset
|
from prototorch.datasets.abstract import NumpyDataset
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
|
|
||||||
|
|
||||||
class Iris(NumpyDataset):
|
class Iris(NumpyDataset):
|
||||||
def __init__(self, dims=None):
|
"""
|
||||||
|
Iris dataset, introduced by Ronald Fisher in 1936.
|
||||||
|
|
||||||
|
The dataset contains four measurements from flowers of three species of iris.
|
||||||
|
|
||||||
|
.. list-table:: Iris
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - dimensions
|
||||||
|
- classes
|
||||||
|
- training size
|
||||||
|
- validation size
|
||||||
|
- test size
|
||||||
|
* - 4
|
||||||
|
- 3
|
||||||
|
- 150
|
||||||
|
- 0
|
||||||
|
- 0
|
||||||
|
|
||||||
|
:param dims: select a subset of dimensions
|
||||||
|
"""
|
||||||
|
def __init__(self, dims: Sequence[int] = None):
|
||||||
x, y = load_iris(return_X_y=True)
|
x, y = load_iris(return_X_y=True)
|
||||||
if dims:
|
if dims:
|
||||||
x = x[:, dims]
|
x = x[:, dims]
|
||||||
|
@ -27,7 +27,27 @@ def make_spiral(n_samples=500, noise=0.3):
|
|||||||
|
|
||||||
|
|
||||||
class Spiral(torch.utils.data.TensorDataset):
|
class Spiral(torch.utils.data.TensorDataset):
|
||||||
"""Spiral dataset for binary classification."""
|
"""Spiral dataset for binary classification.
|
||||||
def __init__(self, n_samples=500, noise=0.3):
|
|
||||||
|
This dataset consists of two spirals of two different classes.
|
||||||
|
|
||||||
|
.. list-table:: Spiral
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - dimensions
|
||||||
|
- classes
|
||||||
|
- training size
|
||||||
|
- validation size
|
||||||
|
- test size
|
||||||
|
* - 2
|
||||||
|
- 2
|
||||||
|
- n_samples
|
||||||
|
- 0
|
||||||
|
- 0
|
||||||
|
|
||||||
|
:param n_samples: number of random samples
|
||||||
|
:param noise: noise added to the spirals
|
||||||
|
"""
|
||||||
|
def __init__(self, n_samples: int = 500, noise: float = 0.3):
|
||||||
x, y = make_spiral(n_samples, noise)
|
x, y = make_spiral(n_samples, noise)
|
||||||
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
super().__init__(torch.Tensor(x), torch.LongTensor(y))
|
||||||
|
@ -40,15 +40,29 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torchvision.datasets.utils import download_file_from_google_drive
|
|
||||||
|
|
||||||
from prototorch.datasets.abstract import ProtoDataset
|
from prototorch.datasets.abstract import ProtoDataset
|
||||||
|
from torchvision.datasets.utils import download_file_from_google_drive
|
||||||
|
|
||||||
|
|
||||||
class Tecator(ProtoDataset):
|
class Tecator(ProtoDataset):
|
||||||
"""
|
"""
|
||||||
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__
|
`Tecator Dataset <http://lib.stat.cmu.edu/datasets/tecator>`__ for classification.
|
||||||
for classification.
|
|
||||||
|
The dataset contains wavelength measurements of meat.
|
||||||
|
|
||||||
|
.. list-table:: Tecator
|
||||||
|
:header-rows: 1
|
||||||
|
|
||||||
|
* - dimensions
|
||||||
|
- classes
|
||||||
|
- training size
|
||||||
|
- validation size
|
||||||
|
- test size
|
||||||
|
* - 100
|
||||||
|
- 2
|
||||||
|
- 129
|
||||||
|
- 43
|
||||||
|
- 43
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_resources = [
|
_resources = [
|
||||||
|
@ -55,4 +55,5 @@ def lvq21_loss(distances, target_labels, prototype_labels):
|
|||||||
"""
|
"""
|
||||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||||
mu = dp - dm
|
mu = dp - dm
|
||||||
|
|
||||||
return mu
|
return mu
|
||||||
|
Loading…
Reference in New Issue
Block a user