[FEATURE] Change NumpyDataset.data to torch.Tensor

This commit is contained in:
Alexander Engelsberger 2021-06-01 17:17:42 +02:00
parent 2722d976f5
commit 4ca581909a

View File

@ -15,9 +15,9 @@ import torch
class NumpyDataset(torch.utils.data.TensorDataset): class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays.""" """Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, data, targets): def __init__(self, data, targets):
self.data = data self.data = torch.Tensor(data)
self.targets = targets self.targets = torch.LongTensor(targets)
tensors = [torch.Tensor(data), torch.Tensor(targets)] tensors = [self.data, self.targets]
super().__init__(*tensors) super().__init__(*tensors)