Add thin wrapper for the Iris dataset

This commit is contained in:
Jensun Ravichandran 2021-05-11 17:06:41 +02:00
parent ae6bc47f87
commit be21412f8a
3 changed files with 5 additions and 8 deletions

View File

@ -67,8 +67,9 @@ class LabeledComponents(Components):
*, *,
initialized_components=None): initialized_components=None):
if initialized_components is not None: if initialized_components is not None:
super().__init__(initialized_components=initialized_components[0]) components, component_labels = initialized_components
self._labels = initialized_components[1] super().__init__(initialized_components=components)
self._labels = component_labels
else: else:
self._initialize_labels(distribution) self._initialize_labels(distribution)
super().__init__(number_of_components=len(self._labels), super().__init__(number_of_components=len(self._labels),

View File

@ -1,11 +1,6 @@
"""ProtoTorch datasets.""" """ProtoTorch datasets."""
from .abstract import NumpyDataset from .abstract import NumpyDataset
from .iris import Iris
from .spiral import Spiral from .spiral import Spiral
from .tecator import Tecator from .tecator import Tecator
__all__ = [
"NumpyDataset",
"Spiral",
"Tecator",
]

View File

@ -23,6 +23,7 @@ INSTALL_REQUIRES = [
] ]
DATASETS = [ DATASETS = [
"requests", "requests",
"sklearn",
"tqdm", "tqdm",
] ]
DEV = ["bumpversion"] DEV = ["bumpversion"]