feat: add CSVDataset
				
					
				
			This commit is contained in:
		| @@ -1,6 +1,6 @@ | ||||
| """ProtoTorch datasets""" | ||||
|  | ||||
| from .abstract import NumpyDataset | ||||
| from .abstract import CSVDataset, NumpyDataset | ||||
| from .sklearn import ( | ||||
|     Blobs, | ||||
|     Circles, | ||||
|   | ||||
| @@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py | ||||
|  | ||||
| import os | ||||
|  | ||||
| import numpy as np | ||||
| import torch | ||||
|  | ||||
|  | ||||
| @@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset): | ||||
|         self.targets = torch.LongTensor(targets) | ||||
|         tensors = [self.data, self.targets] | ||||
|         super().__init__(*tensors) | ||||
|  | ||||
|  | ||||
| class CSVDataset(NumpyDataset): | ||||
|     """Create a Dataset from a CSV file.""" | ||||
|     def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0): | ||||
|         raw = np.genfromtxt( | ||||
|             filepath, | ||||
|             delimiter=delimiter, | ||||
|             skip_header=skip_header, | ||||
|         ) | ||||
|         data = np.delete(raw, 1, target_col) | ||||
|         targets = raw[:, target_col] | ||||
|         super().__init__(data, targets) | ||||
|   | ||||
| @@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase): | ||||
|         self.assertEqual(len(ds), 3) | ||||
|  | ||||
|  | ||||
| class TestCSVDataset(unittest.TestCase): | ||||
|     def setUp(self): | ||||
|         data = np.random.rand(100, 4) | ||||
|         targets = np.random.randint(2, size=(100, 1)) | ||||
|         arr = np.hstack([data, targets]) | ||||
|         np.savetxt("./artifacts/test.csv", arr, delimiter=",") | ||||
|  | ||||
|     def test_len(self): | ||||
|         ds = pt.datasets.CSVDataset("./artifacts/test.csv") | ||||
|         self.assertEqual(len(ds), 100) | ||||
|  | ||||
|     def tearDown(self): | ||||
|         os.remove("./artifacts/test.csv") | ||||
|  | ||||
|  | ||||
| class TestSpiral(unittest.TestCase): | ||||
|     def test_init(self): | ||||
|         ds = pt.datasets.Spiral(num_samples=10) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user