Accept dataloaders for component initialization

This commit is contained in:
Jensun Ravichandran 2021-05-21 11:59:57 +02:00
parent 30adbf705c
commit e3f8828da4

View File

@ -10,6 +10,12 @@ from torch.utils.data import DataLoader, Dataset
def parse_data_arg(data):
if isinstance(data, Dataset):
data, labels = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, DataLoader):
data = torch.tensor([])
labels = torch.tensor([])
for x, y in data:
data = torch.cat([data, x])
labels = torch.cat([labels, y])
else:
data, labels = data
if not isinstance(data, torch.Tensor):