feat: add CSVDataset
This commit is contained in:
@@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase):
|
||||
self.assertEqual(len(ds), 3)
|
||||
|
||||
|
||||
class TestCSVDataset(unittest.TestCase):
|
||||
def setUp(self):
|
||||
data = np.random.rand(100, 4)
|
||||
targets = np.random.randint(2, size=(100, 1))
|
||||
arr = np.hstack([data, targets])
|
||||
np.savetxt("./artifacts/test.csv", arr, delimiter=",")
|
||||
|
||||
def test_len(self):
|
||||
ds = pt.datasets.CSVDataset("./artifacts/test.csv")
|
||||
self.assertEqual(len(ds), 100)
|
||||
|
||||
def tearDown(self):
|
||||
os.remove("./artifacts/test.csv")
|
||||
|
||||
|
||||
class TestSpiral(unittest.TestCase):
|
||||
def test_init(self):
|
||||
ds = pt.datasets.Spiral(num_samples=10)
|
||||
|
Reference in New Issue
Block a user