fix: All examples should work on CPU and GPU now

This commit is contained in:
Alexander Engelsberger 2021-08-05 11:20:02 +02:00
parent 0af8cf36f8
commit d7834e2cc0
No known key found for this signature in database
GPG Key ID: BE3F5909FF0D83E3
5 changed files with 5 additions and 6 deletions

View File

@ -96,7 +96,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
)
def compute_distances(self, x):
protos = self.proto_layer()
protos = self.proto_layer().type_as(x)
distances = self.distance_layer(x, protos)
return distances

View File

@ -134,4 +134,4 @@ class GNGCallback(pl.Callback):
pl_module.errors[
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer)
trainer.accelerator.setup_optimizers(trainer)

View File

@ -96,8 +96,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
return lower_bound
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
protos = self.proto_layer.components
plabels = self.proto_layer.labels
protos, plabels = self.proto_layer()
x, y = train_batch
dis = self.compute_distances(x)

View File

@ -53,7 +53,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
grid = self._grid.view(-1, 2)
gd = squared_euclidean_distance(wp, grid)
nh = torch.exp(-gd / self._sigma**2)
protos = self.proto_layer.components
protos = self.proto_layer()
diff = x.unsqueeze(dim=1) - protos
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)

View File

@ -5,7 +5,7 @@ failed=0
for example in $(find $1 -maxdepth 1 -name "*.py")
do
echo -n "$x" $example '... '
export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt
export DISPLAY= && python $example --fast_dev_run 1 --gpus 0 &> run_log.txt
if [[ $? -ne 0 ]]; then
echo "FAILED!!"
cat run_log.txt