Minor aesthetic changes
This commit is contained in:
parent
0b2aaa42b8
commit
44e4709387
@ -1,6 +1,12 @@
|
|||||||
"""ProtoTorch datasets."""
|
"""ProtoTorch datasets"""
|
||||||
|
|
||||||
from .abstract import NumpyDataset
|
from .abstract import NumpyDataset
|
||||||
from .sklearn import Blobs, Circles, Iris, Moons, Random
|
from .sklearn import (
|
||||||
|
Blobs,
|
||||||
|
Circles,
|
||||||
|
Iris,
|
||||||
|
Moons,
|
||||||
|
Random,
|
||||||
|
)
|
||||||
from .spiral import Spiral
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
"""ProtoTorch abstract dataset classes.
|
"""ProtoTorch abstract dataset classes
|
||||||
|
|
||||||
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
|
Based on `torchvision.VisionDataset` and `torchvision.MNIST`.
|
||||||
|
|
||||||
For the original code, see:
|
For the original code, see:
|
||||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
|
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
|
||||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
@ -12,15 +13,6 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class NumpyDataset(torch.utils.data.TensorDataset):
|
|
||||||
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
|
||||||
def __init__(self, data, targets):
|
|
||||||
self.data = torch.Tensor(data)
|
|
||||||
self.targets = torch.LongTensor(targets)
|
|
||||||
tensors = [self.data, self.targets]
|
|
||||||
super().__init__(*tensors)
|
|
||||||
|
|
||||||
|
|
||||||
class Dataset(torch.utils.data.Dataset):
|
class Dataset(torch.utils.data.Dataset):
|
||||||
"""Abstract dataset class to be inherited."""
|
"""Abstract dataset class to be inherited."""
|
||||||
|
|
||||||
@ -44,7 +36,7 @@ class ProtoDataset(Dataset):
|
|||||||
training_file = "training.pt"
|
training_file = "training.pt"
|
||||||
test_file = "test.pt"
|
test_file = "test.pt"
|
||||||
|
|
||||||
def __init__(self, root, train=True, download=True, verbose=True):
|
def __init__(self, root="", train=True, download=True, verbose=True):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
self.train = train # training set or test set
|
self.train = train # training set or test set
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
@ -96,3 +88,12 @@ class ProtoDataset(Dataset):
|
|||||||
|
|
||||||
def _download(self):
|
def _download(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||||
|
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||||
|
def __init__(self, data, targets):
|
||||||
|
self.data = torch.Tensor(data)
|
||||||
|
self.targets = torch.LongTensor(targets)
|
||||||
|
tensors = [self.data, self.targets]
|
||||||
|
super().__init__(*tensors)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
"""ProtoFlow color utilities."""
|
"""ProtoFlow color utilities"""
|
||||||
|
|
||||||
|
|
||||||
def hex_to_rgb(hex_values):
|
def hex_to_rgb(hex_values):
|
||||||
|
Loading…
Reference in New Issue
Block a user