[BUGFIX] examples/dynamic_pruning.py works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:31:39 +02:00
parent d2856383e2
commit 6197d7d5d6
2 changed files with 2 additions and 2 deletions

View File

@ -36,7 +36,7 @@ if __name__ == "__main__":
# Initialize the model
model = pt.models.CELVQ(
hparams,
prototype_initializer=pt.components.Ones(2, scale=3),
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
)
# Compute intermediate input and output sizes

View File

@ -20,7 +20,7 @@ class CELVQ(GLVQ):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x) # [None, num_protos]
plabels = self.proto_layer.component_labels
plabels = self.proto_layer.labels
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning
batch_loss = self.loss(probs, y.long())