Minor aesthetic changes

This commit is contained in:
Jensun Ravichandran 2021-06-11 23:42:19 +02:00
parent 0b2aaa42b8
commit 44e4709387
3 changed files with 22 additions and 15 deletions

View File

@ -1,6 +1,12 @@
"""ProtoTorch datasets."""
"""ProtoTorch datasets"""
from .abstract import NumpyDataset
from .sklearn import Blobs, Circles, Iris, Moons, Random
from .sklearn import (
Blobs,
Circles,
Iris,
Moons,
Random,
)
from .spiral import Spiral
from .tecator import Tecator

View File

@ -1,10 +1,11 @@
"""ProtoTorch abstract dataset classes.
"""ProtoTorch abstract dataset classes
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
Based on `torchvision.VisionDataset` and `torchvision.MNIST`.
For the original code, see:
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
"""
import os
@ -12,15 +13,6 @@ import os
import torch
class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, data, targets):
self.data = torch.Tensor(data)
self.targets = torch.LongTensor(targets)
tensors = [self.data, self.targets]
super().__init__(*tensors)
class Dataset(torch.utils.data.Dataset):
"""Abstract dataset class to be inherited."""
@ -44,7 +36,7 @@ class ProtoDataset(Dataset):
training_file = "training.pt"
test_file = "test.pt"
def __init__(self, root, train=True, download=True, verbose=True):
def __init__(self, root="", train=True, download=True, verbose=True):
super().__init__(root)
self.train = train # training set or test set
self.verbose = verbose
@ -96,3 +88,12 @@ class ProtoDataset(Dataset):
def _download(self):
raise NotImplementedError
class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, data, targets):
self.data = torch.Tensor(data)
self.targets = torch.LongTensor(targets)
tensors = [self.data, self.targets]
super().__init__(*tensors)

View File

@ -1,4 +1,4 @@
"""ProtoFlow color utilities."""
"""ProtoFlow color utilities"""
def hex_to_rgb(hex_values):