From be21412f8ae441720f5af464b0347a8a15bb7e12 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 11 May 2021 17:06:41 +0200 Subject: [PATCH] Add thin wrapper for the Iris dataset --- prototorch/components/components.py | 5 +++-- prototorch/datasets/__init__.py | 7 +------ setup.py | 1 + 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/prototorch/components/components.py b/prototorch/components/components.py index 10a89d0..589079c 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -67,8 +67,9 @@ class LabeledComponents(Components): *, initialized_components=None): if initialized_components is not None: - super().__init__(initialized_components=initialized_components[0]) - self._labels = initialized_components[1] + components, component_labels = initialized_components + super().__init__(initialized_components=components) + self._labels = component_labels else: self._initialize_labels(distribution) super().__init__(number_of_components=len(self._labels), diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 051d850..e4fa517 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -1,11 +1,6 @@ """ProtoTorch datasets.""" from .abstract import NumpyDataset +from .iris import Iris from .spiral import Spiral from .tecator import Tecator - -__all__ = [ - "NumpyDataset", - "Spiral", - "Tecator", -] diff --git a/setup.py b/setup.py index 6c7b9a6..872bba1 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ INSTALL_REQUIRES = [ ] DATASETS = [ "requests", + "sklearn", "tqdm", ] DEV = ["bumpversion"]