Files
prototorch/prototorch/datasets/sklearn.py
2022-05-17 11:56:18 +02:00

166 lines
4.0 KiB
Python

"""Thin wrappers for a few scikit-learn datasets.
URL:
https://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets
"""
from __future__ import annotations
import warnings
from typing import Sequence
from sklearn.datasets import (
load_iris,
make_blobs,
make_circles,
make_classification,
make_moons,
)
from prototorch.datasets.abstract import NumpyDataset
class Iris(NumpyDataset):
    """Fisher's Iris dataset (1936) exposed as a ``NumpyDataset``.

    Each sample holds four measurements taken from a flower belonging to one
    of three iris species.

    .. list-table:: Iris
        :header-rows: 1

        * - dimensions
          - classes
          - training size
          - validation size
          - test size
        * - 4
          - 3
          - 150
          - 0
          - 0

    :param dims: indices of the feature columns to keep; ``None`` keeps all
    """

    def __init__(self, dims: Sequence[int] | None = None):
        features, labels = load_iris(return_X_y=True)
        # Column selection is optional: without ``dims`` the feature matrix
        # is handed to the base class untouched.
        selected = features if dims is None else features[:, dims]
        super().__init__(selected, labels)
class Blobs(NumpyDataset):
    """Isotropic Gaussian blobs for clustering, produced by ``make_blobs``.

    Read more at
    https://scikit-learn.org/stable/datasets/sample_generators.html#sample-generators.

    :param num_samples: forwarded to ``make_blobs`` as ``n_samples``
    :param num_features: forwarded to ``make_blobs`` as ``n_features``
    :param seed: forwarded to ``make_blobs`` as ``random_state``
    """

    def __init__(
        self,
        num_samples: int = 300,
        num_features: int = 2,
        seed: None | int = 0,
    ):
        # ``centers=None`` leaves the choice of centers to scikit-learn, and
        # ``shuffle=False`` keeps the samples in the order they are generated.
        points, cluster_ids = make_blobs(
            n_samples=num_samples,
            n_features=num_features,
            centers=None,
            shuffle=False,
            random_state=seed,
        )
        super().__init__(points, cluster_ids)
class Random(NumpyDataset):
    """Generate a random n-class classification problem.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_classification.html.

    Note: ``num_classes * num_clusters <= 2**num_informative`` must hold.

    :param num_samples: number of samples to generate
    :param num_features: requested number of features; raised to
        ``num_informative`` (with a warning) when it is smaller than that
    :param num_classes: number of classes
    :param num_clusters: number of clusters per class
    :param num_informative: number of informative features; when omitted
        (``None`` or ``0``) the smallest value satisfying the constraint
        above is used, but never less than one
    :param separation: forwarded to ``make_classification`` as ``class_sep``
    :param seed: forwarded to ``make_classification`` as ``random_state``
    """

    def __init__(
        self,
        num_samples: int = 300,
        num_features: int = 2,
        num_classes: int = 2,
        num_clusters: int = 2,
        num_informative: None | int = None,
        separation: float = 1.0,
        seed: None | int = 0,
    ):
        # A falsy value (``None`` or ``0``) means "derive it automatically".
        if not num_informative:
            import math

            # ceil(log2(k)) is the smallest n with k <= 2**n. For a single
            # class with a single cluster this is 0, which is not a usable
            # number of informative features, so clamp the result to 1.
            num_informative = max(
                1,
                math.ceil(math.log2(num_classes * num_clusters)),
            )

        # With no redundant features requested, every informative feature
        # needs its own column, so grow the feature count when necessary.
        if num_features < num_informative:
            # stacklevel=2 attributes the warning to the code constructing
            # the dataset rather than to this module.
            warnings.warn(
                "Generating more features than requested.",
                stacklevel=2,
            )
            num_features = num_informative

        x, y = make_classification(
            num_samples,
            num_features,
            n_informative=num_informative,
            n_redundant=0,
            n_classes=num_classes,
            n_clusters_per_class=num_clusters,
            class_sep=separation,
            random_state=seed,
            shuffle=False,
        )
        super().__init__(x, y)
class Circles(NumpyDataset):
    """A large circle enclosing a smaller one in 2D.

    Simple toy data for visualizing clustering and classification algorithms.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_circles.html

    :param num_samples: forwarded to ``make_circles`` as ``n_samples``
    :param noise: forwarded to ``make_circles`` as ``noise``
    :param factor: forwarded to ``make_circles`` as ``factor``
    :param seed: forwarded to ``make_circles`` as ``random_state``
    """

    def __init__(
        self,
        num_samples: int = 300,
        noise: float = 0.3,
        factor: float = 0.8,
        seed: None | int = 0,
    ):
        # Shuffling is disabled so the samples keep their generated order.
        points, ring_ids = make_circles(
            n_samples=num_samples,
            shuffle=False,
            noise=noise,
            random_state=seed,
            factor=factor,
        )
        super().__init__(points, ring_ids)
class Moons(NumpyDataset):
    """Two interleaving half circles.

    Simple toy data for visualizing clustering and classification algorithms.

    Read more at
    https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_moons.html

    :param num_samples: forwarded to ``make_moons`` as ``n_samples``
    :param noise: forwarded to ``make_moons`` as ``noise``
    :param seed: forwarded to ``make_moons`` as ``random_state``
    """

    def __init__(
        self,
        num_samples: int = 300,
        noise: float = 0.3,
        seed: None | int = 0,
    ):
        # Shuffling is disabled so the samples keep their generated order.
        points, moon_ids = make_moons(
            n_samples=num_samples,
            shuffle=False,
            noise=noise,
            random_state=seed,
        )
        super().__init__(points, moon_ids)