NumpyDataset now has data and targets properties

This commit is contained in:
Jensun Ravichandran 2021-05-18 19:38:40 +02:00
parent 736d9a6349
commit ee42fd68b1

View File

@ -14,8 +14,10 @@ import torch
class NumpyDataset(torch.utils.data.TensorDataset): class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays.""" """Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, *arrays): def __init__(self, data, targets):
tensors = [torch.Tensor(arr) for arr in arrays] self.data = data
self.targets = targets
tensors = [torch.Tensor(data), torch.Tensor(targets)]
super().__init__(*tensors) super().__init__(*tensors)