[FEATURE] Add more initializers

This commit is contained in:
Jensun Ravichandran
2021-06-14 19:53:02 +02:00
parent 549e6a10c1
commit fc9edeaa97
3 changed files with 198 additions and 102 deletions

View File

@@ -67,17 +67,19 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
elif isinstance(user_distribution, list):
return distribution_from_list(user_distribution, clabels)
else:
msg = f"`distribution` not understood." \
msg = f"`distribution` was not understood." \
f"You have provided: {user_distribution}."
raise ValueError(msg)
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
"""Return data and target as torch tensors."""
if isinstance(data_arg, Dataset):
ds_size = len(data_arg)
data_arg = DataLoader(data_arg, batch_size=ds_size)
loader = DataLoader(data_arg, batch_size=ds_size)
data, targets = next(iter(loader))
if isinstance(data_arg, DataLoader):
elif isinstance(data_arg, DataLoader):
data = torch.tensor([])
targets = torch.tensor([])
for x, y in data_arg:
@@ -87,11 +89,11 @@ def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
assert len(data_arg) == 2
data, targets = data_arg
if not isinstance(data, torch.Tensor):
wmsg = f"Converting data to {torch.Tensor}."
wmsg = f"Converting data to {torch.Tensor}..."
warnings.warn(wmsg)
data = torch.Tensor(data)
if not isinstance(targets, torch.LongTensor):
wmsg = f"Converting targets to {torch.LongTensor}."
wmsg = f"Converting targets to {torch.LongTensor}..."
warnings.warn(wmsg)
targets = torch.LongTensor(targets)
return data, targets