diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 1d61061..096fc6f 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -1,6 +1,12 @@ -"""ProtoTorch datasets.""" +"""ProtoTorch datasets""" from .abstract import NumpyDataset -from .sklearn import Blobs, Circles, Iris, Moons, Random +from .sklearn import ( + Blobs, + Circles, + Iris, + Moons, + Random, +) from .spiral import Spiral from .tecator import Tecator diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index e941c95..dac8f8c 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -1,10 +1,11 @@ -"""ProtoTorch abstract dataset classes. +"""ProtoTorch abstract dataset classes -Based on `torchvision.VisionDataset` and `torchvision.MNIST` +Based on `torchvision.VisionDataset` and `torchvision.MNIST`. For the original code, see: https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py + """ import os @@ -12,15 +13,6 @@ import os import torch -class NumpyDataset(torch.utils.data.TensorDataset): - """Create a PyTorch TensorDataset from NumPy arrays.""" - def __init__(self, data, targets): - self.data = torch.Tensor(data) - self.targets = torch.LongTensor(targets) - tensors = [self.data, self.targets] - super().__init__(*tensors) - - class Dataset(torch.utils.data.Dataset): """Abstract dataset class to be inherited.""" @@ -44,7 +36,7 @@ class ProtoDataset(Dataset): training_file = "training.pt" test_file = "test.pt" - def __init__(self, root, train=True, download=True, verbose=True): + def __init__(self, root="", train=True, download=True, verbose=True): super().__init__(root) self.train = train # training set or test set self.verbose = verbose @@ -96,3 +88,12 @@ class ProtoDataset(Dataset): def _download(self): raise NotImplementedError + + +class NumpyDataset(torch.utils.data.TensorDataset): + """Create a PyTorch TensorDataset from NumPy arrays.""" + def __init__(self, data, targets): + self.data = torch.Tensor(data) + self.targets = torch.LongTensor(targets) + tensors = [self.data, self.targets] + super().__init__(*tensors) diff --git a/prototorch/utils/colors.py b/prototorch/utils/colors.py index 07e2d5d..61ad1a0 100644 --- a/prototorch/utils/colors.py +++ b/prototorch/utils/colors.py @@ -1,4 +1,4 @@ -"""ProtoFlow color utilities.""" +"""ProtoFlow color utilities""" def hex_to_rgb(hex_values):