From 639198e77433ea432aec639e6007589d52d73039 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 17 May 2021 16:57:13 +0200 Subject: [PATCH] Update Iris dataset --- prototorch/datasets/iris.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/prototorch/datasets/iris.py b/prototorch/datasets/iris.py index 15bd396..aa16caa 100644 --- a/prototorch/datasets/iris.py +++ b/prototorch/datasets/iris.py @@ -10,6 +10,8 @@ from sklearn.datasets import load_iris class Iris(NumpyDataset): - def __init__(self): + def __init__(self, dims=None): x, y = load_iris(return_X_y=True) + if dims: + x = x[:, dims] super().__init__(x, y)