[BUGFIX] examples/dynamic_pruning.py works again

This commit is contained in:
Jensun Ravichandran 2021-06-14 20:31:39 +02:00
parent d2856383e2
commit 6197d7d5d6
2 changed files with 2 additions and 2 deletions

View File

@ -36,7 +36,7 @@ if __name__ == "__main__":
# Initialize the model # Initialize the model
model = pt.models.CELVQ( model = pt.models.CELVQ(
hparams, hparams,
prototype_initializer=pt.components.Ones(2, scale=3), prototypes_initializer=pt.initializers.FVCI(2, 3.0),
) )
# Compute intermediate input and output sizes # Compute intermediate input and output sizes

View File

@ -20,7 +20,7 @@ class CELVQ(GLVQ):
def shared_step(self, batch, batch_idx, optimizer_idx=None): def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch x, y = batch
out = self.compute_distances(x) # [None, num_protos] out = self.compute_distances(x) # [None, num_protos]
plabels = self.proto_layer.component_labels plabels = self.proto_layer.labels
winning = stratified_min_pooling(out, plabels) # [None, num_classes] winning = stratified_min_pooling(out, plabels) # [None, num_classes]
probs = -1.0 * winning probs = -1.0 * winning
batch_loss = self.loss(probs, y.long()) batch_loss = self.loss(probs, y.long())