Add one-hot support in functions/initializers.py
parent c11a3860df
commit 532f63b1de
functions/initializers.py
@@ -13,71 +13,84 @@ def register_initializer(function):
     return function


-def labels_from(distribution):
+def labels_from(distribution, one_hot=True):
     """Takes a distribution tensor and returns a labels tensor."""
     nclasses = distribution.shape[0]
     llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
     # labels = [l for cl in llist for l in cl] # flatten the list of lists
-    labels = list(chain(*llist)) # flatten using itertools.chain
-    return torch.tensor(labels, requires_grad=False)
+    flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
+    plabels = torch.tensor(flat_llist, requires_grad=False)
+    if one_hot:
+        return torch.eye(nclasses)[plabels]
+    return plabels


 @register_initializer
-def ones(x_train, y_train, prototype_distribution):
+def ones(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.ones(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def zeros(x_train, y_train, prototype_distribution):
+def zeros(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.zeros(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def rand(x_train, y_train, prototype_distribution):
+def rand(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.rand(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def randn(x_train, y_train, prototype_distribution):
+def randn(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.randn(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def stratified_mean(x_train, y_train, prototype_distribution):
+def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
-    plabels = labels_from(prototype_distribution)
-    for i, l in enumerate(plabels):
-        xl = x_train[y_train == l]
+    plabels = labels_from(prototype_distribution, one_hot)
+    for i, label in enumerate(plabels):
+        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
+        if one_hot:
+            nclasses = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        xl = x_train[matcher]
         mean_xl = torch.mean(xl, dim=0)
         protos[i] = mean_xl
+    plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels


 @register_initializer
-def stratified_random(x_train, y_train, prototype_distribution):
+def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
-    plabels = labels_from(prototype_distribution)
-    for i, l in enumerate(plabels):
-        xl = x_train[y_train == l]
+    plabels = labels_from(prototype_distribution, one_hot)
+    for i, label in enumerate(plabels):
+        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
+        if one_hot:
+            nclasses = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        xl = x_train[matcher]
         rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
         random_xl = xl[rand_index]
         protos[i] = random_xl
+    plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels

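For reference, here is a minimal standalone sketch of what the new one_hot flag does in labels_from. The distribution below (2 prototypes for class 0, 1 for class 1) is made up for illustration; the logic mirrors the diff:

    import torch
    from itertools import chain

    distribution = torch.tensor([2, 1])  # hypothetical prototype distribution
    nclasses = distribution.shape[0]
    llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
    plabels = torch.tensor(list(chain(*llist)))  # tensor([0, 0, 1])

    # With one_hot=True, labels_from indexes an identity matrix instead of
    # returning the integer labels directly:
    one_hot_labels = torch.eye(nclasses)[plabels]
    # tensor([[1., 0.],
    #         [1., 0.],
    #         [0., 1.]])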
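The matcher logic added to stratified_mean and stratified_random is the subtle part: with one-hot targets, a prototype label can no longer be compared to y_train with a plain ==. A small sketch of the broadcast-and-reduce comparison, with invented data (4 samples, 2 classes):

    import torch

    y_train = torch.tensor([[1., 0.],
                            [0., 1.],
                            [1., 0.],
                            [0., 1.]])  # hypothetical one-hot targets
    label = torch.tensor([1., 0.])      # one-hot prototype label for class 0

    # Broadcast the label against every row of y_train ...
    matcher = torch.eq(label.unsqueeze(dim=0), y_train)  # shape (4, 2), boolean
    # ... and keep only the rows where every component agrees.
    nclasses = y_train.size()[1]
    matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
    # matcher is now tensor([ True, False,  True, False]),
    # so x_train[matcher] selects exactly the class-0 samples.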