prototorch/tests/test_datasets.py
2021-06-18 19:10:29 +02:00

161 lines
5.0 KiB
Python

"""ProtoTorch datasets test suite"""
import os
import shutil
import unittest
import numpy as np
import torch
import prototorch as pt
from prototorch.datasets.abstract import Dataset, ProtoDataset
class TestAbstract(unittest.TestCase):
    """Checks on the bare ``Dataset`` class from ``datasets.abstract``."""

    def setUp(self):
        """Create a fresh ``Dataset("./artifacts")`` before every test."""
        self.ds = Dataset("./artifacts")

    def tearDown(self):
        """Remove the dataset attribute created in ``setUp``."""
        del self.ds

    def test_getitem(self):
        """Indexing the dataset is expected to raise NotImplementedError."""
        self.assertRaises(NotImplementedError, lambda: self.ds[0])

    def test_len(self):
        """``len()`` on the dataset is expected to raise NotImplementedError."""
        self.assertRaises(NotImplementedError, len, self.ds)
class TestProtoDataset(unittest.TestCase):
    """Constructor checks for ``ProtoDataset``."""

    def test_download(self):
        """Constructing with ``download=True`` must raise NotImplementedError."""
        self.assertRaises(
            NotImplementedError,
            ProtoDataset,
            "./artifacts",
            download=True,
        )

    def test_exists(self):
        """Constructing with ``download=False`` must raise RuntimeError."""
        self.assertRaises(
            RuntimeError,
            ProtoDataset,
            "./artifacts",
            download=False,
        )
class TestNumpyDataset(unittest.TestCase):
    """Length checks for ``pt.datasets.NumpyDataset``."""

    def test_list_init(self):
        """A dataset built from one-element lists reports length 1."""
        samples = [1]
        labels = [1]
        dataset = pt.datasets.NumpyDataset(samples, labels)
        self.assertEqual(len(dataset), 1)

    def test_numpy_init(self):
        """A dataset built from a (3, 2) array and 3 targets has length 3."""
        samples = np.random.randn(3, 2)
        labels = np.array([0, 1, 2])
        dataset = pt.datasets.NumpyDataset(samples, labels)
        self.assertEqual(len(dataset), 3)
class TestSpiral(unittest.TestCase):
    """Length check for ``pt.datasets.Spiral``."""

    def test_init(self):
        """Requesting ``num_samples=10`` yields a dataset of length 10."""
        spiral = pt.datasets.Spiral(num_samples=10)
        observed = len(spiral)
        self.assertEqual(observed, 10)
class TestIris(unittest.TestCase):
    """Size and dimensionality checks for ``pt.datasets.Iris``."""

    def setUp(self):
        """Instantiate the default Iris dataset before every test."""
        self.ds = pt.datasets.Iris()

    def test_size(self):
        """The default dataset has length 150."""
        observed = len(self.ds)
        self.assertEqual(observed, 150)

    def test_dims(self):
        """The second axis of ``data`` has size 4 by default."""
        width = self.ds.data.shape[1]
        self.assertEqual(width, 4)

    def test_dims_selection(self):
        """Passing ``dims=[0, 1]`` gives ``data`` a second axis of size 2."""
        subset = pt.datasets.Iris(dims=[0, 1])
        width = subset.data.shape[1]
        self.assertEqual(width, 2)
class TestBlobs(unittest.TestCase):
    """Length check for ``pt.datasets.Blobs``."""

    def test_size(self):
        """Requesting ``num_samples=10`` yields a dataset of length 10."""
        blobs = pt.datasets.Blobs(num_samples=10)
        observed = len(blobs)
        self.assertEqual(observed, 10)
class TestRandom(unittest.TestCase):
    """Length check for ``pt.datasets.Random``."""

    def test_size(self):
        """Requesting ``num_samples=10`` yields a dataset of length 10."""
        random_ds = pt.datasets.Random(num_samples=10)
        observed = len(random_ds)
        self.assertEqual(observed, 10)
class TestCircles(unittest.TestCase):
    """Length check for ``pt.datasets.Circles``."""

    def test_size(self):
        """Requesting ``num_samples=10`` yields a dataset of length 10."""
        circles = pt.datasets.Circles(num_samples=10)
        observed = len(circles)
        self.assertEqual(observed, 10)
class TestMoons(unittest.TestCase):
    """Length check for ``pt.datasets.Moons``."""

    def test_size(self):
        """Requesting ``num_samples=10`` yields a dataset of length 10."""
        moons = pt.datasets.Moons(num_samples=10)
        observed = len(moons)
        self.assertEqual(observed, 10)
# class TestTecator(unittest.TestCase):
# def setUp(self):
# self.artifacts_dir = "./artifacts/Tecator"
# self._remove_artifacts()
# def _remove_artifacts(self):
# if os.path.exists(self.artifacts_dir):
# shutil.rmtree(self.artifacts_dir)
# def test_download_false(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# self._remove_artifacts()
# with self.assertRaises(RuntimeError):
# _ = pt.datasets.Tecator(rootdir, download=False)
# def test_download_caching(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# _ = pt.datasets.Tecator(rootdir, download=True, verbose=False)
# _ = pt.datasets.Tecator(rootdir, download=False, verbose=False)
# def test_repr(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# train = pt.datasets.Tecator(rootdir, download=True, verbose=True)
# self.assertTrue("Split: Train" in train.__repr__())
# def test_download_train(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# train = pt.datasets.Tecator(root=rootdir,
# train=True,
# download=True,
# verbose=False)
# train = pt.datasets.Tecator(root=rootdir, download=True, verbose=False)
# x_train, y_train = train.data, train.targets
# self.assertEqual(x_train.shape[0], 144)
# self.assertEqual(y_train.shape[0], 144)
# self.assertEqual(x_train.shape[1], 100)
# def test_download_test(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
# x_test, y_test = test.data, test.targets
# self.assertEqual(x_test.shape[0], 71)
# self.assertEqual(y_test.shape[0], 71)
# self.assertEqual(x_test.shape[1], 100)
# def test_class_to_idx(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
# _ = test.class_to_idx
# def test_getitem(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
# x, y = test[0]
# self.assertEqual(x.shape[0], 100)
# self.assertIsInstance(y, int)
# def test_loadable_with_dataloader(self):
# rootdir = self.artifacts_dir.rpartition("/")[0]
# test = pt.datasets.Tecator(root=rootdir, train=False, verbose=False)
# _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
# def tearDown(self):
# self._remove_artifacts()