From cf7d7b5d9dd12db4dd625c4018a614916d4c9a7a Mon Sep 17 00:00:00 2001 From: blackfly Date: Tue, 14 Apr 2020 19:47:59 +0200 Subject: [PATCH] Add tests/test_datasets.py --- tests/test_datasets.py | 95 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/test_datasets.py diff --git a/tests/test_datasets.py b/tests/test_datasets.py new file mode 100644 index 0000000..2ee3fd3 --- /dev/null +++ b/tests/test_datasets.py @@ -0,0 +1,95 @@ +"""ProtoTorch datasets test suite.""" + +import os +import shutil +import unittest + +import torch + +from prototorch.datasets import abstract, tecator + + +class TestAbstract(unittest.TestCase): + def test_getitem(self): + with self.assertRaises(NotImplementedError): + abstract.Dataset('./artifacts')[0] + + def test_len(self): + with self.assertRaises(NotImplementedError): + len(abstract.Dataset('./artifacts')) + + +class TestProtoDataset(unittest.TestCase): + def test_getitem(self): + with self.assertRaises(NotImplementedError): + abstract.ProtoDataset('./artifacts')[0] + + def test_download(self): + with self.assertRaises(NotImplementedError): + abstract.ProtoDataset('./artifacts').download() + + +class TestTecator(unittest.TestCase): + def setUp(self): + self.artifacts_dir = './artifacts/Tecator' + self._remove_artifacts() + + def _remove_artifacts(self): + if os.path.exists(self.artifacts_dir): + shutil.rmtree(self.artifacts_dir) + + def test_download_false(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + self._remove_artifacts() + with self.assertRaises(RuntimeError): + _ = tecator.Tecator(rootdir, download=False) + + def test_download_caching(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + _ = tecator.Tecator(rootdir, download=True, verbose=False) + _ = tecator.Tecator(rootdir, download=False, verbose=False) + + def test_repr(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + train = tecator.Tecator(rootdir, download=True, verbose=True) + self.assertTrue('Split: Train' in train.__repr__()) + + def test_download_train(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + train = tecator.Tecator(root=rootdir, + train=True, + download=True, + verbose=False) + train = tecator.Tecator(root=rootdir, download=True, verbose=False) + x_train, y_train = train.data, train.targets + self.assertEqual(x_train.shape[0], 144) + self.assertEqual(y_train.shape[0], 144) + self.assertEqual(x_train.shape[1], 100) + + def test_download_test(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + test = tecator.Tecator(root=rootdir, train=False, verbose=False) + x_test, y_test = test.data, test.targets + self.assertEqual(x_test.shape[0], 71) + self.assertEqual(y_test.shape[0], 71) + self.assertEqual(x_test.shape[1], 100) + + def test_class_to_idx(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + test = tecator.Tecator(root=rootdir, train=False, verbose=False) + _ = test.class_to_idx + + def test_getitem(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + test = tecator.Tecator(root=rootdir, train=False, verbose=False) + x, y = test[0] + self.assertEqual(x.shape[0], 100) + self.assertIsInstance(y, int) + + def test_loadable_with_dataloader(self): + rootdir = self.artifacts_dir.rpartition('/')[0] + test = tecator.Tecator(root=rootdir, train=False, verbose=False) + _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True) + + def tearDown(self): + pass