Fix generator bug in stratified_random initializer

This commit is contained in:
blackfly 2020-04-14 19:51:54 +02:00
parent 2b82830590
commit 3cfbc49254

View File

@ -69,16 +69,13 @@ def stratified_mean(x_train, y_train, prototype_distribution):
@register_initializer
def stratified_random(x_train, y_train, prototype_distribution):
gen = torch.manual_seed(torch.initial_seed())
nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim)
plabels = labels_from(prototype_distribution)
for i, l in enumerate(plabels):
xl = x_train[y_train == l]
rand_index = torch.zeros(1).long().random_(0,
xl.shape[1] - 1,
generator=gen)
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
random_xl = xl[rand_index]
protos[i] = random_xl
return protos, plabels