feat: add CSVDataset

This commit is contained in:
Jensun Ravichandran
2021-07-04 16:30:01 +02:00
parent eb79b703d8
commit fdb9a7c66d
3 changed files with 30 additions and 1 deletions

View File

@@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase):
self.assertEqual(len(ds), 3)
class TestCSVDataset(unittest.TestCase):
def setUp(self):
data = np.random.rand(100, 4)
targets = np.random.randint(2, size=(100, 1))
arr = np.hstack([data, targets])
np.savetxt("./artifacts/test.csv", arr, delimiter=",")
def test_len(self):
ds = pt.datasets.CSVDataset("./artifacts/test.csv")
self.assertEqual(len(ds), 100)
def tearDown(self):
os.remove("./artifacts/test.csv")
class TestSpiral(unittest.TestCase):
def test_init(self):
ds = pt.datasets.Spiral(num_samples=10)