[BUGFIX] examples/dynamic_pruning.py
works again
This commit is contained in:
parent
d2856383e2
commit
6197d7d5d6
@ -36,7 +36,7 @@ if __name__ == "__main__":
|
|||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.CELVQ(
|
model = pt.models.CELVQ(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=pt.components.Ones(2, scale=3),
|
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
|
@ -20,7 +20,7 @@ class CELVQ(GLVQ):
|
|||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x) # [None, num_protos]
|
out = self.compute_distances(x) # [None, num_protos]
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.labels
|
||||||
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
winning = stratified_min_pooling(out, plabels) # [None, num_classes]
|
||||||
probs = -1.0 * winning
|
probs = -1.0 * winning
|
||||||
batch_loss = self.loss(probs, y.long())
|
batch_loss = self.loss(probs, y.long())
|
||||||
|
Loading…
Reference in New Issue
Block a user