diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 66fd11e..5766c31 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -54,6 +54,8 @@ class TestCSVDataset(unittest.TestCase): data = np.random.rand(100, 4) targets = np.random.randint(2, size=(100, 1)) arr = np.hstack([data, targets]) + if not os.path.exists("./artifacts"): + os.mkdir("./artifacts") np.savetxt("./artifacts/test.csv", arr, delimiter=",") def test_len(self):