fix: All examples should work on CPU and GPU now
This commit is contained in:
parent
0af8cf36f8
commit
d7834e2cc0
@ -96,7 +96,7 @@ class UnsupervisedPrototypeModel(PrototypeModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos = self.proto_layer()
|
protos = self.proto_layer().type_as(x)
|
||||||
distances = self.distance_layer(x, protos)
|
distances = self.distance_layer(x, protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
@ -134,4 +134,4 @@ class GNGCallback(pl.Callback):
|
|||||||
pl_module.errors[
|
pl_module.errors[
|
||||||
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
||||||
|
|
||||||
trainer.accelerator_backend.setup_optimizers(trainer)
|
trainer.accelerator.setup_optimizers(trainer)
|
||||||
|
@ -96,8 +96,7 @@ class MedianLVQ(NonGradientMixin, GLVQ):
|
|||||||
return lower_bound
|
return lower_bound
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
protos = self.proto_layer.components
|
protos, plabels = self.proto_layer()
|
||||||
plabels = self.proto_layer.labels
|
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self.compute_distances(x)
|
dis = self.compute_distances(x)
|
||||||
|
@ -53,7 +53,7 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
|
|||||||
grid = self._grid.view(-1, 2)
|
grid = self._grid.view(-1, 2)
|
||||||
gd = squared_euclidean_distance(wp, grid)
|
gd = squared_euclidean_distance(wp, grid)
|
||||||
nh = torch.exp(-gd / self._sigma**2)
|
nh = torch.exp(-gd / self._sigma**2)
|
||||||
protos = self.proto_layer.components
|
protos = self.proto_layer()
|
||||||
diff = x.unsqueeze(dim=1) - protos
|
diff = x.unsqueeze(dim=1) - protos
|
||||||
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
|
||||||
updated_protos = protos + delta.sum(dim=0)
|
updated_protos = protos + delta.sum(dim=0)
|
||||||
|
@ -5,7 +5,7 @@ failed=0
|
|||||||
for example in $(find $1 -maxdepth 1 -name "*.py")
|
for example in $(find $1 -maxdepth 1 -name "*.py")
|
||||||
do
|
do
|
||||||
echo -n "$x" $example '... '
|
echo -n "$x" $example '... '
|
||||||
export DISPLAY= && python $example --fast_dev_run 1 &> run_log.txt
|
export DISPLAY= && python $example --fast_dev_run 1 --gpus 0 &> run_log.txt
|
||||||
if [[ $? -ne 0 ]]; then
|
if [[ $? -ne 0 ]]; then
|
||||||
echo "FAILED!!"
|
echo "FAILED!!"
|
||||||
cat run_log.txt
|
cat run_log.txt
|
||||||
|
Loading…
Reference in New Issue
Block a user