[BUGFIX] examples/dynamic_pruning.py
works again
This commit is contained in:
parent
d2856383e2
commit
6197d7d5d6
@ -36,7 +36,7 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = pt.models.CELVQ(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.Ones(2, scale=3),
|
||||
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
|
@ -20,7 +20,7 @@ class CELVQ(GLVQ):
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self.compute_distances(x) # [None, num_protos]
|
||||
plabels = self.proto_layer.component_labels
|
||||
plabels = self.proto_layer.labels
|
||||
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
||||
probs = -1.0 * winning
|
||||
batch_loss = self.loss(probs, y.long())
|
||||
|
Loading…
Reference in New Issue
Block a user