feat: add CSVDataset
This commit is contained in:
parent
eb79b703d8
commit
fdb9a7c66d
@ -1,6 +1,6 @@
|
|||||||
"""ProtoTorch datasets"""
|
"""ProtoTorch datasets"""
|
||||||
|
|
||||||
from .abstract import NumpyDataset
|
from .abstract import CSVDataset, NumpyDataset
|
||||||
from .sklearn import (
|
from .sklearn import (
|
||||||
Blobs,
|
Blobs,
|
||||||
Circles,
|
Circles,
|
||||||
|
@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
|
|||||||
self.targets = torch.LongTensor(targets)
|
self.targets = torch.LongTensor(targets)
|
||||||
tensors = [self.data, self.targets]
|
tensors = [self.data, self.targets]
|
||||||
super().__init__(*tensors)
|
super().__init__(*tensors)
|
||||||
|
|
||||||
|
|
||||||
|
class CSVDataset(NumpyDataset):
    """Create a Dataset from a CSV file.

    The file is parsed with ``np.genfromtxt``. The column at ``target_col``
    supplies the targets; all remaining columns supply the data. Both are
    handed to ``NumpyDataset.__init__``.

    Args:
        filepath: Path of the CSV file, forwarded to ``np.genfromtxt``.
        target_col: Index of the column holding the targets. Defaults to
            ``-1`` (the last column).
        delimiter: Field separator, forwarded to ``np.genfromtxt``.
        skip_header: Number of leading lines to skip, forwarded to
            ``np.genfromtxt``.
    """
    def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
        raw = np.genfromtxt(
            filepath,
            delimiter=delimiter,
            skip_header=skip_header,
        )
        # NOTE(review): the indexing below assumes ``raw`` is 2-D, i.e. the
        # file has more than one row and more than one column -- confirm.
        # np.delete takes (arr, obj, axis): drop the target column (obj) along
        # the column axis so that the labels do not leak into the features.
        data = np.delete(raw, target_col, axis=1)
        targets = raw[:, target_col]
        super().__init__(data, targets)
|
||||||
|
@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase):
|
|||||||
self.assertEqual(len(ds), 3)
|
self.assertEqual(len(ds), 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCSVDataset(unittest.TestCase):
    """Exercise ``CSVDataset`` against a CSV file written before each test."""
    def setUp(self):
        # 100 rows: four random feature columns followed by one 0/1 label
        # column, written comma-separated to the artifacts directory.
        features = np.random.rand(100, 4)
        labels = np.random.randint(2, size=(100, 1))
        np.savetxt(
            "./artifacts/test.csv",
            np.hstack([features, labels]),
            delimiter=",",
        )

    def tearDown(self):
        # Remove the file created in setUp.
        os.remove("./artifacts/test.csv")

    def test_len(self):
        # One dataset item is expected per CSV row.
        dataset = pt.datasets.CSVDataset("./artifacts/test.csv")
        self.assertEqual(len(dataset), 100)
|
||||||
|
|
||||||
|
|
||||||
class TestSpiral(unittest.TestCase):
|
class TestSpiral(unittest.TestCase):
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
ds = pt.datasets.Spiral(num_samples=10)
|
ds = pt.datasets.Spiral(num_samples=10)
|
||||||
|
Loading…
Reference in New Issue
Block a user