[FEATURE] Add wrappers for more sklearn datasets
This commit is contained in:
parent
d8a0b2dfcc
commit
2eb7b05653
@ -1,8 +1,6 @@
|
|||||||
"""ProtoTorch datasets."""
|
"""ProtoTorch datasets."""
|
||||||
|
|
||||||
from .abstract import NumpyDataset
|
from .abstract import NumpyDataset
|
||||||
from .iris import Iris
|
from .sklearn import Blobs, Circles, Iris, Moons, Random
|
||||||
from .spiral import Spiral
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
|
||||||
__all__ = ['Iris', 'Spiral', 'Tecator']
|
|
||||||
|
@ -1,40 +0,0 @@
|
|||||||
"""Thin wrapper for the Iris classification dataset from sklearn.
|
|
||||||
|
|
||||||
URL:
|
|
||||||
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Sequence
|
|
||||||
|
|
||||||
from prototorch.datasets.abstract import NumpyDataset
|
|
||||||
from sklearn.datasets import load_iris
|
|
||||||
|
|
||||||
|
|
||||||
class Iris(NumpyDataset):
    """Iris Dataset by Ronald Fisher introduced in 1936.

    The dataset contains four measurements from flowers of three species of
    iris.

    .. list-table:: Iris
       :header-rows: 1

       * - dimensions
         - classes
         - training size
         - validation size
         - test size
       * - 4
         - 3
         - 150
         - 0
         - 0

    :param dims: select a subset of dimensions (column indices into the
        feature matrix); ``None`` or an empty sequence keeps every column
    """
    def __init__(self, dims: Sequence[int] = None):
        x, y = load_iris(return_X_y=True)
        # Test for None and emptiness explicitly rather than relying on bare
        # truthiness: ``if dims:`` raises ValueError for a multi-element
        # numpy index array because its truth value is ambiguous.
        if dims is not None and len(dims) > 0:
            x = x[:, dims]
        super().__init__(x, y)
|
|
137
prototorch/datasets/sklearn.py
Normal file
137
prototorch/datasets/sklearn.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
"""Thin wrappers for a few scikit-learn datasets.
|
||||||
|
|
||||||
|
URL:
|
||||||
|
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from prototorch.datasets.abstract import NumpyDataset
|
||||||
|
|
||||||
|
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
||||||
|
make_classification, make_moons)
|
||||||
|
|
||||||
|
|
||||||
|
class Iris(NumpyDataset):
    """Iris Dataset by Ronald Fisher introduced in 1936.

    The dataset contains four measurements from flowers of three species of
    iris.

    .. list-table:: Iris
       :header-rows: 1

       * - dimensions
         - classes
         - training size
         - validation size
         - test size
       * - 4
         - 3
         - 150
         - 0
         - 0

    :param dims: select a subset of dimensions (column indices into the
        feature matrix); ``None`` or an empty sequence keeps every column
    """
    def __init__(self, dims: Sequence[int] = None):
        x, y = load_iris(return_X_y=True)
        # Test for None and emptiness explicitly rather than relying on bare
        # truthiness: ``if dims:`` raises ValueError for a multi-element
        # numpy index array because its truth value is ambiguous.
        if dims is not None and len(dims) > 0:
            x = x[:, dims]
        super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Blobs(NumpyDataset):
    """Isotropic Gaussian blobs, generated with ``sklearn.datasets.make_blobs``.

    Read more at
    https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.

    The number and location of the blob centers is left to scikit-learn
    (``centers=None``) and the samples are not shuffled.

    :param num_samples: forwarded to ``make_blobs`` as ``n_samples``
    :param num_features: forwarded to ``make_blobs`` as ``n_features``
    :param seed: forwarded to ``make_blobs`` as ``random_state``
    """
    def __init__(self,
                 num_samples: int = 300,
                 num_features: int = 2,
                 seed: Union[None, int] = 0):
        features, labels = make_blobs(
            n_samples=num_samples,
            n_features=num_features,
            centers=None,
            shuffle=False,
            random_state=seed,
        )
        super().__init__(features, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class Random(NumpyDataset):
    """Generate a random n-class classification problem.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html.

    Note: ``n_classes * n_clusters_per_class <= 2**n_informative`` must hold.

    :param num_samples: forwarded to ``make_classification`` as ``n_samples``
    :param num_features: forwarded as ``n_features``; raised to
        ``num_informative`` (with a warning) if it is smaller
    :param num_classes: forwarded as ``n_classes``
    :param num_clusters: forwarded as ``n_clusters_per_class``
    :param num_informative: forwarded as ``n_informative``; when ``None``
        (or ``0``) the smallest count satisfying the constraint above is
        derived from ``num_classes`` and ``num_clusters``
    :param separation: forwarded as ``class_sep``
    :param seed: forwarded as ``random_state``
    """
    def __init__(self,
                 num_samples: int = 300,
                 num_features: int = 2,
                 num_classes: int = 2,
                 num_clusters: int = 2,
                 num_informative: Union[None, int] = None,
                 separation: float = 1.0,
                 seed: Union[None, int] = 0):
        if not num_informative:
            import math

            # Smallest n with num_classes * num_clusters <= 2**n. For a
            # single class with a single cluster log2(1) == 0, which is not
            # a usable number of informative features, so clamp to one.
            num_informative = max(
                1, math.ceil(math.log2(num_classes * num_clusters)))
        if num_features < num_informative:
            # stacklevel=2 attributes the warning to the code constructing
            # the dataset rather than to this module.
            warnings.warn("Generating more features than requested.",
                          stacklevel=2)
            num_features = num_informative
        x, y = make_classification(num_samples,
                                   num_features,
                                   n_informative=num_informative,
                                   n_redundant=0,
                                   n_classes=num_classes,
                                   n_clusters_per_class=num_clusters,
                                   class_sep=separation,
                                   random_state=seed,
                                   shuffle=False)
        super().__init__(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
class Circles(NumpyDataset):
    """A large circle containing a smaller circle in 2D.

    A simple toy dataset to visualize clustering and classification
    algorithms, generated with ``sklearn.datasets.make_circles``. The samples
    are not shuffled.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html

    :param num_samples: forwarded to ``make_circles`` as ``n_samples``
    :param noise: forwarded to ``make_circles`` as ``noise``
    :param factor: forwarded to ``make_circles`` as ``factor``
    :param seed: forwarded to ``make_circles`` as ``random_state``
    """
    def __init__(self,
                 num_samples: int = 300,
                 noise: float = 0.3,
                 factor: float = 0.8,
                 seed: Union[None, int] = 0):
        points, labels = make_circles(
            n_samples=num_samples,
            shuffle=False,
            noise=noise,
            random_state=seed,
            factor=factor,
        )
        super().__init__(points, labels)
|
||||||
|
|
||||||
|
|
||||||
|
class Moons(NumpyDataset):
    """Two interleaving half circles.

    A simple toy dataset to visualize clustering and classification
    algorithms, generated with ``sklearn.datasets.make_moons``. The samples
    are not shuffled.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html

    :param num_samples: forwarded to ``make_moons`` as ``n_samples``
    :param noise: forwarded to ``make_moons`` as ``noise``
    :param seed: forwarded to ``make_moons`` as ``random_state``
    """
    def __init__(self,
                 num_samples: int = 300,
                 noise: float = 0.3,
                 seed: Union[None, int] = 0):
        points, labels = make_moons(
            n_samples=num_samples,
            shuffle=False,
            noise=noise,
            random_state=seed,
        )
        super().__init__(points, labels)
|
Loading…
Reference in New Issue
Block a user