From 91b57b01b1d99b16008427db3d7d26d619c71343 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 2 Jun 2021 00:29:45 +0200 Subject: [PATCH] [REFACTOR] `neighbour` -> `neighbor` --- prototorch/models/unsupervised.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/prototorch/models/unsupervised.py b/prototorch/models/unsupervised.py index 8be6e37..235a1a9 100644 --- a/prototorch/models/unsupervised.py +++ b/prototorch/models/unsupervised.py @@ -38,18 +38,18 @@ class GNGCallback(Callback): # Insertion point worst = torch.argmax(errors) - neighbours = topology.get_neighbours(worst)[0] + neighbors = topology.get_neighbors(worst)[0] - if len(neighbours) == 0: - logging.log("No Neighbour pair found") + if len(neighbors) == 0: + logging.log(level=20, msg="No neighbor-pairs found!") return - neighbours_errors = errors[neighbours] - worst_neighbour = neighbours[torch.argmax(neighbours_errors)] + neighbors_errors = errors[neighbors] + worst_neighbor = neighbors[torch.argmax(neighbors_errors)] # New Prototype new_component = 0.5 * (components[worst] + - components[worst_neighbour]) + components[worst_neighbor]) # Add component pl_module.proto_layer.add_components( @@ -58,15 +58,15 @@ class GNGCallback(Callback): # Adjust Topology topology.add_prototype() topology.add_connection(worst, -1) - topology.add_connection(worst_neighbour, -1) - topology.remove_connection(worst, worst_neighbour) + topology.add_connection(worst_neighbor, -1) + topology.remove_connection(worst, worst_neighbor) # New errors worst_error = errors[worst].unsqueeze(0) pl_module.errors = torch.cat([pl_module.errors, worst_error]) pl_module.errors[worst] = errors[worst] * self.reduction pl_module.errors[ - worst_neighbour] = errors[worst_neighbour] * self.reduction + worst_neighbor] = errors[worst_neighbor] * self.reduction trainer.accelerator_backend.setup_optimizers(trainer) @@ -98,7 +98,7 @@ class ConnectionTopology(torch.nn.Module): self.cmat[i0][self.age[i0] > self.agelimit] = 0 self.cmat[i1][self.age[i1] > self.agelimit] = 0 - def get_neighbours(self, position): + def get_neighbors(self, position): return torch.where(self.cmat[position]) def add_prototype(self):