Update datasets test suite
This commit is contained in:
parent
44e4709387
commit
4a99bcbf0d
@ -1,32 +1,97 @@
|
|||||||
"""ProtoTorch datasets test suite."""
|
"""ProtoTorch datasets test suite"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.datasets import abstract, tecator
|
import prototorch as pt
|
||||||
|
from prototorch.datasets.abstract import Dataset, ProtoDataset
|
||||||
|
|
||||||
|
|
||||||
class TestAbstract(unittest.TestCase):
|
class TestAbstract(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.ds = Dataset("./artifacts")
|
||||||
|
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.Dataset("./artifacts")[0]
|
_ = self.ds[0]
|
||||||
|
|
||||||
def test_len(self):
|
def test_len(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
len(abstract.Dataset("./artifacts"))
|
_ = len(self.ds)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del self.ds
|
||||||
|
|
||||||
|
|
||||||
class TestProtoDataset(unittest.TestCase):
|
class TestProtoDataset(unittest.TestCase):
|
||||||
def test_getitem(self):
|
|
||||||
with self.assertRaises(NotImplementedError):
|
|
||||||
abstract.ProtoDataset("./artifacts")[0]
|
|
||||||
|
|
||||||
def test_download(self):
|
def test_download(self):
|
||||||
with self.assertRaises(NotImplementedError):
|
with self.assertRaises(NotImplementedError):
|
||||||
abstract.ProtoDataset("./artifacts").download()
|
_ = ProtoDataset("./artifacts", download=True)
|
||||||
|
|
||||||
|
def test_exists(self):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
_ = ProtoDataset("./artifacts", download=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumpyDataset(unittest.TestCase):
|
||||||
|
def test_list_init(self):
|
||||||
|
ds = pt.datasets.NumpyDataset([1], [1])
|
||||||
|
self.assertEqual(len(ds), 1)
|
||||||
|
|
||||||
|
def test_numpy_init(self):
|
||||||
|
data = np.random.randn(3, 2)
|
||||||
|
targets = np.array([0, 1, 2])
|
||||||
|
ds = pt.datasets.NumpyDataset(data, targets)
|
||||||
|
self.assertEqual(len(ds), 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpiral(unittest.TestCase):
|
||||||
|
def test_init(self):
|
||||||
|
ds = pt.datasets.Spiral(num_samples=10)
|
||||||
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIris(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.ds = pt.datasets.Iris()
|
||||||
|
|
||||||
|
def test_size(self):
|
||||||
|
self.assertEqual(len(self.ds), 150)
|
||||||
|
|
||||||
|
def test_dims(self):
|
||||||
|
self.assertEqual(self.ds.data.shape[1], 4)
|
||||||
|
|
||||||
|
def test_dims_selection(self):
|
||||||
|
ds = pt.datasets.Iris(dims=[0, 1])
|
||||||
|
self.assertEqual(ds.data.shape[1], 2)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlobs(unittest.TestCase):
|
||||||
|
def test_size(self):
|
||||||
|
ds = pt.datasets.Blobs(num_samples=10)
|
||||||
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRandom(unittest.TestCase):
|
||||||
|
def test_size(self):
|
||||||
|
ds = pt.datasets.Random(num_samples=10)
|
||||||
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCircles(unittest.TestCase):
|
||||||
|
def test_size(self):
|
||||||
|
ds = pt.datasets.Circles(num_samples=10)
|
||||||
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoons(unittest.TestCase):
|
||||||
|
def test_size(self):
|
||||||
|
ds = pt.datasets.Moons(num_samples=10)
|
||||||
|
self.assertEqual(len(ds), 10)
|
||||||
|
|
||||||
|
|
||||||
class TestTecator(unittest.TestCase):
|
class TestTecator(unittest.TestCase):
|
||||||
@ -42,25 +107,25 @@ class TestTecator(unittest.TestCase):
|
|||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
self._remove_artifacts()
|
self._remove_artifacts()
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
_ = tecator.Tecator(rootdir, download=False)
|
_ = pt.datasets.Tecator(rootdir, download=False)
|
||||||
|
|
||||||
def test_download_caching(self):
|
def test_download_caching(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
_ = tecator.Tecator(rootdir, download=True, verbose=False)
|
_ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
|
||||||
_ = tecator.Tecator(rootdir, download=False, verbose=False)
|
_ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
|
||||||
|
|
||||||
def test_repr(self):
|
def test_repr(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = tecator.Tecator(rootdir, download=True, verbose=True)
|
train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
|
||||||
self.assertTrue("Split: Train" in train.__repr__())
|
self.assertTrue("Split: Train" in train.__repr__())
|
||||||
|
|
||||||
def test_download_train(self):
|
def test_download_train(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
train = tecator.Tecator(root=rootdir,
|
train = pt.datasets.Tecator(root=rootdir,
|
||||||
train=True,
|
train=True,
|
||||||
download=True,
|
download=True,
|
||||||
verbose=False)
|
verbose=False)
|
||||||
train = tecator.Tecator(root=rootdir, download=True, verbose=False)
|
train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
|
||||||
x_train, y_train = train.data, train.targets
|
x_train, y_train = train.data, train.targets
|
||||||
self.assertEqual(x_train.shape[0], 144)
|
self.assertEqual(x_train.shape[0], 144)
|
||||||
self.assertEqual(y_train.shape[0], 144)
|
self.assertEqual(y_train.shape[0], 144)
|
||||||
@ -68,7 +133,7 @@ class TestTecator(unittest.TestCase):
|
|||||||
|
|
||||||
def test_download_test(self):
|
def test_download_test(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x_test, y_test = test.data, test.targets
|
x_test, y_test = test.data, test.targets
|
||||||
self.assertEqual(x_test.shape[0], 71)
|
self.assertEqual(x_test.shape[0], 71)
|
||||||
self.assertEqual(y_test.shape[0], 71)
|
self.assertEqual(y_test.shape[0], 71)
|
||||||
@ -76,20 +141,20 @@ class TestTecator(unittest.TestCase):
|
|||||||
|
|
||||||
def test_class_to_idx(self):
|
def test_class_to_idx(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = test.class_to_idx
|
_ = test.class_to_idx
|
||||||
|
|
||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
x, y = test[0]
|
x, y = test[0]
|
||||||
self.assertEqual(x.shape[0], 100)
|
self.assertEqual(x.shape[0], 100)
|
||||||
self.assertIsInstance(y, int)
|
self.assertIsInstance(y, int)
|
||||||
|
|
||||||
def test_loadable_with_dataloader(self):
|
def test_loadable_with_dataloader(self):
|
||||||
rootdir = self.artifacts_dir.rpartition("/")[0]
|
rootdir = self.artifacts_dir.rpartition("/")[0]
|
||||||
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
self._remove_artifacts()
|
||||||
|
Loading…
Reference in New Issue
Block a user