Minor aesthetic changes
This commit is contained in:
		@@ -1,6 +1,12 @@
 | 
				
			|||||||
"""ProtoTorch datasets."""
 | 
					"""ProtoTorch datasets"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .abstract import NumpyDataset
 | 
					from .abstract import NumpyDataset
 | 
				
			||||||
from .sklearn import Blobs, Circles, Iris, Moons, Random
 | 
					from .sklearn import (
 | 
				
			||||||
 | 
					    Blobs,
 | 
				
			||||||
 | 
					    Circles,
 | 
				
			||||||
 | 
					    Iris,
 | 
				
			||||||
 | 
					    Moons,
 | 
				
			||||||
 | 
					    Random,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from .spiral import Spiral
 | 
					from .spiral import Spiral
 | 
				
			||||||
from .tecator import Tecator
 | 
					from .tecator import Tecator
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,10 +1,11 @@
 | 
				
			|||||||
"""ProtoTorch abstract dataset classes.
 | 
					"""ProtoTorch abstract dataset classes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
 | 
					Based on `torchvision.VisionDataset` and `torchvision.MNIST`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
For the original code, see:
 | 
					For the original code, see:
 | 
				
			||||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
 | 
					https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
 | 
				
			||||||
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
 | 
					https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
@@ -12,15 +13,6 @@ import os
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class NumpyDataset(torch.utils.data.TensorDataset):
 | 
					 | 
				
			||||||
    """Create a PyTorch TensorDataset from NumPy arrays."""
 | 
					 | 
				
			||||||
    def __init__(self, data, targets):
 | 
					 | 
				
			||||||
        self.data = torch.Tensor(data)
 | 
					 | 
				
			||||||
        self.targets = torch.LongTensor(targets)
 | 
					 | 
				
			||||||
        tensors = [self.data, self.targets]
 | 
					 | 
				
			||||||
        super().__init__(*tensors)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class Dataset(torch.utils.data.Dataset):
 | 
					class Dataset(torch.utils.data.Dataset):
 | 
				
			||||||
    """Abstract dataset class to be inherited."""
 | 
					    """Abstract dataset class to be inherited."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -44,7 +36,7 @@ class ProtoDataset(Dataset):
 | 
				
			|||||||
    training_file = "training.pt"
 | 
					    training_file = "training.pt"
 | 
				
			||||||
    test_file = "test.pt"
 | 
					    test_file = "test.pt"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, root, train=True, download=True, verbose=True):
 | 
					    def __init__(self, root="", train=True, download=True, verbose=True):
 | 
				
			||||||
        super().__init__(root)
 | 
					        super().__init__(root)
 | 
				
			||||||
        self.train = train  # training set or test set
 | 
					        self.train = train  # training set or test set
 | 
				
			||||||
        self.verbose = verbose
 | 
					        self.verbose = verbose
 | 
				
			||||||
@@ -96,3 +88,12 @@ class ProtoDataset(Dataset):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def _download(self):
 | 
					    def _download(self):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class NumpyDataset(torch.utils.data.TensorDataset):
 | 
				
			||||||
 | 
					    """Create a PyTorch TensorDataset from NumPy arrays."""
 | 
				
			||||||
 | 
					    def __init__(self, data, targets):
 | 
				
			||||||
 | 
					        self.data = torch.Tensor(data)
 | 
				
			||||||
 | 
					        self.targets = torch.LongTensor(targets)
 | 
				
			||||||
 | 
					        tensors = [self.data, self.targets]
 | 
				
			||||||
 | 
					        super().__init__(*tensors)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
"""ProtoFlow color utilities."""
 | 
					"""ProtoFlow color utilities"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def hex_to_rgb(hex_values):
 | 
					def hex_to_rgb(hex_values):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user