diff --git a/prototorch/models/cbc.py b/prototorch/models/cbc.py index 2934c56..64b48c1 100644 --- a/prototorch/models/cbc.py +++ b/prototorch/models/cbc.py @@ -153,4 +153,4 @@ class ImageCBC(CBC): """ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): # super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) - self.component_layer.prototypes.data.clamp_(0.0, 1.0) + self.component_layer.components.data.clamp_(0.0, 1.0)