diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 096fc6f..1fd485f 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -1,6 +1,6 @@ """ProtoTorch datasets""" -from .abstract import NumpyDataset +from .abstract import CSVDataset, NumpyDataset from .sklearn import ( Blobs, Circles, diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index dac8f8c..f4b6660 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py import os +import numpy as np import torch @@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset): self.targets = torch.LongTensor(targets) tensors = [self.data, self.targets] super().__init__(*tensors) + + +class CSVDataset(NumpyDataset): + """Create a Dataset from a CSV file.""" + def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0): + raw = np.genfromtxt( + filepath, + delimiter=delimiter, + skip_header=skip_header, + ) + data = np.delete(raw, 1, target_col) + targets = raw[:, target_col] + super().__init__(data, targets) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ea058de..66fd11e 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase): self.assertEqual(len(ds), 3) +class TestCSVDataset(unittest.TestCase): + def setUp(self): + data = np.random.rand(100, 4) + targets = np.random.randint(2, size=(100, 1)) + arr = np.hstack([data, targets]) + np.savetxt("./artifacts/test.csv", arr, delimiter=",") + + def test_len(self): + ds = pt.datasets.CSVDataset("./artifacts/test.csv") + self.assertEqual(len(ds), 100) + + def tearDown(self): + os.remove("./artifacts/test.csv") + + class TestSpiral(unittest.TestCase): def test_init(self): ds = pt.datasets.Spiral(num_samples=10)