feat: add CSVDataset
This commit is contained in:
parent
eb79b703d8
commit
fdb9a7c66d
@ -1,6 +1,6 @@
|
||||
"""ProtoTorch datasets"""
|
||||
|
||||
from .abstract import NumpyDataset
|
||||
from .abstract import CSVDataset, NumpyDataset
|
||||
from .sklearn import (
|
||||
Blobs,
|
||||
Circles,
|
||||
|
@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
|
||||
self.targets = torch.LongTensor(targets)
|
||||
tensors = [self.data, self.targets]
|
||||
super().__init__(*tensors)
|
||||
|
||||
|
||||
class CSVDataset(NumpyDataset):
    """Create a Dataset from a CSV file.

    The file is parsed with ``np.genfromtxt``. One column is used as the
    targets and all remaining columns are used as the data.

    Args:
        filepath: Location of the CSV file, forwarded to ``np.genfromtxt``.
        target_col: Index of the column that holds the targets. Defaults to
            ``-1`` (the last column).
        delimiter: Field separator, forwarded to ``np.genfromtxt``.
        skip_header: Number of leading lines to skip, forwarded to
            ``np.genfromtxt``.
    """
    def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
        raw = np.genfromtxt(
            filepath,
            delimiter=delimiter,
            skip_header=skip_header,
        )
        # np.delete(arr, obj, axis): remove the target column (obj) along the
        # column axis (1), so that `data` holds every column except the targets.
        data = np.delete(raw, target_col, axis=1)
        targets = raw[:, target_col]
        super().__init__(data, targets)
|
||||
|
@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase):
|
||||
self.assertEqual(len(ds), 3)
|
||||
|
||||
|
||||
class TestCSVDataset(unittest.TestCase):
    """Exercise ``pt.datasets.CSVDataset`` on a CSV file written to disk."""
    def setUp(self):
        # 100 rows: four random feature columns followed by one 0/1 column.
        features = np.random.rand(100, 4)
        labels = np.random.randint(2, size=(100, 1))
        table = np.concatenate((features, labels), axis=1)
        np.savetxt("./artifacts/test.csv", table, delimiter=",")

    def tearDown(self):
        # Remove the file created in setUp.
        os.remove("./artifacts/test.csv")

    def test_len(self):
        # The dataset should expose one item per CSV row.
        dataset = pt.datasets.CSVDataset("./artifacts/test.csv")
        self.assertEqual(len(dataset), 100)
|
||||
|
||||
|
||||
class TestSpiral(unittest.TestCase):
|
||||
def test_init(self):
|
||||
ds = pt.datasets.Spiral(num_samples=10)
|
||||
|
Loading…
Reference in New Issue
Block a user