Update datasets test suite

This commit is contained in:
Jensun Ravichandran 2021-06-11 23:43:18 +02:00
parent 44e4709387
commit 4a99bcbf0d

View File

@@ -1,32 +1,97 @@
"""ProtoTorch datasets test suite.""" """ProtoTorch datasets test suite"""
import os import os
import shutil import shutil
import unittest import unittest
import numpy as np
import torch import torch
from prototorch.datasets import abstract, tecator import prototorch as pt
from prototorch.datasets.abstract import Dataset, ProtoDataset
class TestAbstract(unittest.TestCase): class TestAbstract(unittest.TestCase):
def setUp(self):
self.ds = Dataset("./artifacts")
def test_getitem(self): def test_getitem(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
abstract.Dataset("./artifacts")[0] _ = self.ds[0]
def test_len(self): def test_len(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
len(abstract.Dataset("./artifacts")) _ = len(self.ds)
def tearDown(self):
del self.ds
class TestProtoDataset(unittest.TestCase): class TestProtoDataset(unittest.TestCase):
def test_getitem(self):
with self.assertRaises(NotImplementedError):
abstract.ProtoDataset("./artifacts")[0]
def test_download(self): def test_download(self):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
abstract.ProtoDataset("./artifacts").download() _ = ProtoDataset("./artifacts", download=True)
def test_exists(self):
with self.assertRaises(RuntimeError):
_ = ProtoDataset("./artifacts", download=False)
class TestNumpyDataset(unittest.TestCase):
def test_list_init(self):
ds = pt.datasets.NumpyDataset([1], [1])
self.assertEqual(len(ds), 1)
def test_numpy_init(self):
data = np.random.randn(3, 2)
targets = np.array([0, 1, 2])
ds = pt.datasets.NumpyDataset(data, targets)
self.assertEqual(len(ds), 3)
class TestSpiral(unittest.TestCase):
def test_init(self):
ds = pt.datasets.Spiral(num_samples=10)
self.assertEqual(len(ds), 10)
class TestIris(unittest.TestCase):
def setUp(self):
self.ds = pt.datasets.Iris()
def test_size(self):
self.assertEqual(len(self.ds), 150)
def test_dims(self):
self.assertEqual(self.ds.data.shape[1], 4)
def test_dims_selection(self):
ds = pt.datasets.Iris(dims=[0, 1])
self.assertEqual(ds.data.shape[1], 2)
class TestBlobs(unittest.TestCase):
def test_size(self):
ds = pt.datasets.Blobs(num_samples=10)
self.assertEqual(len(ds), 10)
class TestRandom(unittest.TestCase):
def test_size(self):
ds = pt.datasets.Random(num_samples=10)
self.assertEqual(len(ds), 10)
class TestCircles(unittest.TestCase):
def test_size(self):
ds = pt.datasets.Circles(num_samples=10)
self.assertEqual(len(ds), 10)
class TestMoons(unittest.TestCase):
def test_size(self):
ds = pt.datasets.Moons(num_samples=10)
self.assertEqual(len(ds), 10)
class TestTecator(unittest.TestCase): class TestTecator(unittest.TestCase):
@@ -42,25 +107,25 @@ class TestTecator(unittest.TestCase):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
self._remove_artifacts() self._remove_artifacts()
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
_ = tecator.Tecator(rootdir, download=False) _ = pt.datasets.Tecator(rootdir, download=False)
def test_download_caching(self): def test_download_caching(self):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
_ = tecator.Tecator(rootdir, download=True, verbose=False) _ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
_ = tecator.Tecator(rootdir, download=False, verbose=False) _ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
def test_repr(self): def test_repr(self):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
train = tecator.Tecator(rootdir, download=True, verbose=True) train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
self.assertTrue("Split: Train" in train.__repr__()) self.assertTrue("Split: Train" in train.__repr__())
def test_download_train(self): def test_download_train(self):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
train = tecator.Tecator(root=rootdir, train = pt.datasets.Tecator(root=rootdir,
train=True, train=True,
download=True, download=True,
verbose=False) verbose=False)
train = tecator.Tecator(root=rootdir, download=True, verbose=False) train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
x_train, y_train = train.data, train.targets x_train, y_train = train.data, train.targets
self.assertEqual(x_train.shape[0], 144) self.assertEqual(x_train.shape[0], 144)
self.assertEqual(y_train.shape[0], 144) self.assertEqual(y_train.shape[0], 144)
@@ -68,7 +133,7 @@ class TestTecator(unittest.TestCase):
def test_download_test(self): def test_download_test(self):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
x_test, y_test = test.data, test.targets x_test, y_test = test.data, test.targets
self.assertEqual(x_test.shape[0], 71) self.assertEqual(x_test.shape[0], 71)
self.assertEqual(y_test.shape[0], 71) self.assertEqual(y_test.shape[0], 71)
@@ -76,20 +141,20 @@ class TestTecator(unittest.TestCase):
def test_class_to_idx(self): def test_class_to_idx(self):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
_ = test.class_to_idx _ = test.class_to_idx
def test_getitem(self): def test_getitem(self):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
x, y = test[0] x, y = test[0]
self.assertEqual(x.shape[0], 100) self.assertEqual(x.shape[0], 100)
self.assertIsInstance(y, int) self.assertIsInstance(y, int)
def test_loadable_with_dataloader(self): def test_loadable_with_dataloader(self):
rootdir = self.artifacts_dir.rpartition("/")[0] rootdir = self.artifacts_dir.rpartition("/")[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False) test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True) _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
def tearDown(self): def tearDown(self):
pass self._remove_artifacts()