From fdb9a7c66d18979ba4c0fe5001a95857e476afe5 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 4 Jul 2021 16:30:01 +0200 Subject: [PATCH] feat: add `CSVDataset` --- prototorch/datasets/__init__.py | 2 +- prototorch/datasets/abstract.py | 14 ++++++++++++++ tests/test_datasets.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 096fc6f..1fd485f 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -1,6 +1,6 @@ """ProtoTorch datasets""" -from .abstract import NumpyDataset +from .abstract import CSVDataset, NumpyDataset from .sklearn import ( Blobs, Circles, diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index dac8f8c..f4b6660 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py import os +import numpy as np import torch @@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset): self.targets = torch.LongTensor(targets) tensors = [self.data, self.targets] super().__init__(*tensors) + + +class CSVDataset(NumpyDataset): + """Create a Dataset from a CSV file.""" + def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0): + raw = np.genfromtxt( + filepath, + delimiter=delimiter, + skip_header=skip_header, + ) + data = np.delete(raw, 1, target_col) + targets = raw[:, target_col] + super().__init__(data, targets) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ea058de..66fd11e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase): self.assertEqual(len(ds), 3) +class TestCSVDataset(unittest.TestCase): + def setUp(self): + data = np.random.rand(100, 4) + targets = np.random.randint(2, size=(100, 1)) + arr = np.hstack([data, targets]) + np.savetxt("./artifacts/test.csv", arr, delimiter=",") + + def test_len(self): + ds = pt.datasets.CSVDataset("./artifacts/test.csv") + self.assertEqual(len(ds), 100) + + def tearDown(self): + os.remove("./artifacts/test.csv") + + class TestSpiral(unittest.TestCase): def test_init(self): ds = pt.datasets.Spiral(num_samples=10)