diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index d80e2d9..d08dbec 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -96,7 +96,7 @@ class UnsupervisedPrototypeModel(PrototypeModel): ) def compute_distances(self, x): - protos = self.proto_layer() + protos = self.proto_layer().type_as(x) distances = self.distance_layer(x, protos) return distances diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py index 62b628a..ed3a141 100644 --- a/prototorch/models/callbacks.py +++ b/prototorch/models/callbacks.py @@ -134,4 +134,4 @@ class GNGCallback(pl.Callback): pl_module.errors[ worst_neighbor] = errors[worst_neighbor] * self.reduction - trainer.accelerator_backend.setup_optimizers(trainer) + trainer.accelerator.setup_optimizers(trainer) diff --git a/prototorch/models/lvq.py b/prototorch/models/lvq.py index a79a2d7..b06a8dc 100644 --- a/prototorch/models/lvq.py +++ b/prototorch/models/lvq.py @@ -96,8 +96,7 @@ class MedianLVQ(NonGradientMixin, GLVQ): return lower_bound def training_step(self, train_batch, batch_idx, optimizer_idx=None): - protos = self.proto_layer.components - plabels = self.proto_layer.labels + protos, plabels = self.proto_layer() x, y = train_batch dis = self.compute_distances(x) diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index 9740b65..c18f033 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -53,7 +53,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel): grid = self._grid.view(-1, 2) gd = squared_euclidean_distance(wp, grid) nh = torch.exp(-gd / self._sigma**2) - protos = self.proto_layer.components + protos = self.proto_layer() diff = x.unsqueeze(dim=1) - protos delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff updated_protos = protos + delta.sum(dim=0) diff --git a/tests/test_examples.sh b/tests/test_examples.sh index f772071..68f105b 100755 --- a/tests/test_examples.sh +++ b/tests/test_examples.sh @@ -5,7 +5,7 @@ failed=0 for example in $(find $1 -maxdepth 1 -name "*.py") do echo -n "$x" $example '... ' - export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt + export DISPLAY= && python $example --fast_dev_run 1 --gpus 0 &> run_log.txt if [[ $? -ne 0 ]]; then echo "FAILED!!" cat run_log.txt