[BUGFIX] Stratified functions work on GPU now
This commit is contained in:
@@ -39,7 +39,7 @@ def stratify_with(values: torch.Tensor,
|
|||||||
# skip if stratification is trivial
|
# skip if stratification is trivial
|
||||||
return values
|
return values
|
||||||
batch_size = values.size()[0]
|
batch_size = values.size()[0]
|
||||||
winning_values = torch.zeros(num_classes, batch_size)
|
winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
|
||||||
filler = torch.full_like(values.T, fill_value=fill_value)
|
filler = torch.full_like(values.T, fill_value=fill_value)
|
||||||
for i, cl in enumerate(clabels):
|
for i, cl in enumerate(clabels):
|
||||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||||
|
Reference in New Issue
Block a user