prototorch/tests/test_datasets.py

96 lines
3.3 KiB
Python
Raw Normal View History

2020-04-14 17:47:59 +00:00
"""ProtoTorch datasets test suite."""
import os
import shutil
import unittest
import torch
from prototorch.datasets import abstract, tecator
class TestAbstract(unittest.TestCase):
    """The bare ``abstract.Dataset`` base must reject indexing and ``len()``."""

    _root = "./artifacts"

    def _dataset(self):
        """Build a fresh abstract ``Dataset`` rooted at ``self._root``."""
        return abstract.Dataset(self._root)

    def test_getitem(self):
        # Construction happens inside the callable, so an error raised by the
        # constructor itself is covered by the assertion too.
        self.assertRaises(NotImplementedError, lambda: self._dataset()[0])

    def test_len(self):
        self.assertRaises(NotImplementedError, lambda: len(self._dataset()))
class TestProtoDataset(unittest.TestCase):
    """``abstract.ProtoDataset`` must reject indexing and ``download()``."""

    _root = "./artifacts"

    def _dataset(self):
        """Build a fresh abstract ``ProtoDataset`` rooted at ``self._root``."""
        return abstract.ProtoDataset(self._root)

    def test_getitem(self):
        # Construction happens inside the callable, so an error raised by the
        # constructor itself is covered by the assertion too.
        self.assertRaises(NotImplementedError, lambda: self._dataset()[0])

    def test_download(self):
        self.assertRaises(NotImplementedError,
                          lambda: self._dataset().download())
class TestTecator(unittest.TestCase):
    """Integration tests for the Tecator dataset.

    ``setUp`` deletes the Tecator artifacts before every test, so each test
    that needs data passes ``download=True`` explicitly rather than relying on
    the library's default. These tests need network access.
    """

    def setUp(self):
        self.artifacts_dir = "./artifacts/Tecator"
        # Tecator is given the parent directory as its root.
        self.rootdir = os.path.dirname(self.artifacts_dir)
        self._remove_artifacts()

    def _remove_artifacts(self):
        """Delete the Tecator artifacts directory if it exists."""
        if os.path.exists(self.artifacts_dir):
            shutil.rmtree(self.artifacts_dir)

    def _load_test_split(self):
        """Download (if needed) and return the Tecator test split quietly."""
        return tecator.Tecator(root=self.rootdir,
                               train=False,
                               download=True,
                               verbose=False)

    def test_download_false(self):
        # setUp has removed the artifacts, so nothing is cached on disk.
        with self.assertRaises(RuntimeError):
            _ = tecator.Tecator(self.rootdir, download=False)

    def test_download_caching(self):
        _ = tecator.Tecator(self.rootdir, download=True, verbose=False)
        # The second construction must succeed from the files cached above.
        _ = tecator.Tecator(self.rootdir, download=False, verbose=False)

    def test_repr(self):
        train = tecator.Tecator(self.rootdir, download=True, verbose=True)
        self.assertIn("Split: Train", repr(train))

    def test_download_train(self):
        train = tecator.Tecator(root=self.rootdir,
                                train=True,
                                download=True,
                                verbose=False)
        x_train, y_train = train.data, train.targets
        self.assertEqual(x_train.shape[0], 144)
        self.assertEqual(y_train.shape[0], 144)
        self.assertEqual(x_train.shape[1], 100)

    def test_download_test(self):
        test = self._load_test_split()
        x_test, y_test = test.data, test.targets
        self.assertEqual(x_test.shape[0], 71)
        self.assertEqual(y_test.shape[0], 71)
        self.assertEqual(x_test.shape[1], 100)

    def test_class_to_idx(self):
        test = self._load_test_split()
        _ = test.class_to_idx

    def test_getitem(self):
        test = self._load_test_split()
        x, y = test[0]
        self.assertEqual(x.shape[0], 100)
        self.assertIsInstance(y, int)

    def test_loadable_with_dataloader(self):
        test = self._load_test_split()
        loader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
        # Building a DataLoader never touches the dataset, so draw one batch
        # to actually exercise indexing and collation.
        x_batch, _ = next(iter(loader))
        self.assertEqual(x_batch.shape[1], 100)

    def tearDown(self):
        # Mirror setUp so downloaded data does not outlive the test.
        self._remove_artifacts()