[FEATURE] Add wrappers for more sklearn datasets

This commit is contained in:
Jensun Ravichandran 2021-06-01 23:33:51 +02:00
parent d8a0b2dfcc
commit 2eb7b05653
3 changed files with 138 additions and 43 deletions

View File

@ -1,8 +1,6 @@
"""ProtoTorch datasets.""" """ProtoTorch datasets."""
from .abstract import NumpyDataset from .abstract import NumpyDataset
from .iris import Iris from .sklearn import Blobs, Circles, Iris, Moons, Random
from .spiral import Spiral from .spiral import Spiral
from .tecator import Tecator from .tecator import Tecator
__all__ = ['Iris', 'Spiral', 'Tecator']

View File

@ -1,40 +0,0 @@
"""Thin wrapper for the Iris classification dataset from sklearn.
URL:
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html
"""
from typing import Sequence
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import load_iris
class Iris(NumpyDataset):
"""
Iris Dataset by Ronald Fisher introduced in 1936.
The dataset contains four measurements from flowers of three species of iris.
.. list-table:: Iris
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 4
- 3
- 150
- 0
- 0
:param dims: select a subset of dimensions
"""
def __init__(self, dims: Sequence[int] = None):
x, y = load_iris(return_X_y=True)
if dims:
x = x[:, dims]
super().__init__(x, y)

View File

@ -0,0 +1,137 @@
"""Thin wrappers for a few scikit-learn datasets.
URL:
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets
"""
import warnings
from typing import Sequence, Union
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import (load_iris, make_blobs, make_circles,
make_classification, make_moons)
class Iris(NumpyDataset):
"""Iris Dataset by Ronald Fisher introduced in 1936.
The dataset contains four measurements from flowers of three species of iris.
.. list-table:: Iris
:header-rows: 1
* - dimensions
- classes
- training size
- validation size
- test size
* - 4
- 3
- 150
- 0
- 0
:param dims: select a subset of dimensions
"""
def __init__(self, dims: Sequence[int] = None):
x, y = load_iris(return_X_y=True)
if dims:
x = x[:, dims]
super().__init__(x, y)
class Blobs(NumpyDataset):
"""Generate isotropic Gaussian blobs for clustering.
Read more at
https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.
"""
def __init__(self,
num_samples: int = 300,
num_features: int = 2,
seed: Union[None, int] = 0):
x, y = make_blobs(num_samples,
num_features,
centers=None,
random_state=seed,
shuffle=False)
super().__init__(x, y)
class Random(NumpyDataset):
"""Generate a random n-class classification problem.
Read more at
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html.
Note: n_classes * n_clusters_per_class <= 2**n_informative must satisfy.
"""
def __init__(self,
num_samples: int = 300,
num_features: int = 2,
num_classes: int = 2,
num_clusters: int = 2,
num_informative: Union[None, int] = None,
separation: float = 1.0,
seed: Union[None, int] = 0):
if not num_informative:
import math
num_informative = math.ceil(math.log2(num_classes * num_clusters))
if num_features < num_informative:
warnings.warn("Generating more features than requested.")
num_features = num_informative
x, y = make_classification(num_samples,
num_features,
n_informative=num_informative,
n_redundant=0,
n_classes=num_classes,
n_clusters_per_class=num_clusters,
class_sep=separation,
random_state=seed,
shuffle=False)
super().__init__(x, y)
class Circles(NumpyDataset):
"""Make a large circle containing a smaller circle in 2D.
A simple toy dataset to visualize clustering and classification algorithms.
Read more at
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html
"""
def __init__(self,
num_samples: int = 300,
noise: float = 0.3,
factor: float = 0.8,
seed: Union[None, int] = 0):
x, y = make_circles(num_samples,
noise=noise,
factor=factor,
random_state=seed,
shuffle=False)
super().__init__(x, y)
class Moons(NumpyDataset):
"""Make two interleaving half circles.
A simple toy dataset to visualize clustering and classification algorithms.
Read more at
https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html
"""
def __init__(self,
num_samples: int = 300,
noise: float = 0.3,
seed: Union[None, int] = 0):
x, y = make_moons(num_samples,
noise=noise,
random_state=seed,
shuffle=False)
super().__init__(x, y)