[BUGFIX] Initializers can handle Dataloaders now

This commit is contained in:
Alexander Engelsberger 2021-05-21 16:00:20 +02:00
parent 14508f0600
commit ee30d4da5b

View File

@ -7,17 +7,18 @@ import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
def parse_data_arg(data): def parse_data_arg(data_arg):
if isinstance(data, Dataset): if isinstance(data_arg, Dataset):
data, labels = next(iter(DataLoader(data, batch_size=len(data)))) data_arg = DataLoader(data_arg, batch_size=len(data_arg))
elif isinstance(data, DataLoader):
if isinstance(data_arg, DataLoader):
data = torch.tensor([]) data = torch.tensor([])
labels = torch.tensor([]) labels = torch.tensor([])
for x, y in data: for x, y in data_arg:
data = torch.cat([data, x]) data = torch.cat([data, x])
labels = torch.cat([labels, y]) labels = torch.cat([labels, y])
else: else:
data, labels = data data, labels = data_arg
if not isinstance(data, torch.Tensor): if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}." wmsg = f"Converting data to {torch.Tensor}."
warnings.warn(wmsg) warnings.warn(wmsg)