[REFACTOR] neighbour -> neighbor

This commit is contained in:
Jensun Ravichandran 2021-06-02 00:29:45 +02:00
parent 9eb6476078
commit 91b57b01b1

View File

@ -38,18 +38,18 @@ class GNGCallback(Callback):
# Insertion point
worst = torch.argmax(errors)
neighbours = topology.get_neighbours(worst)[0]
neighbors = topology.get_neighbors(worst)[0]
if len(neighbours) == 0:
logging.log("No Neighbour pair found")
if len(neighbors) == 0:
logging.log(level=20, msg="No neighbor-pairs found!")
return
neighbours_errors = errors[neighbours]
worst_neighbour = neighbours[torch.argmax(neighbours_errors)]
neighbors_errors = errors[neighbors]
worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
# New Prototype
new_component = 0.5 * (components[worst] +
components[worst_neighbour])
components[worst_neighbor])
# Add component
pl_module.proto_layer.add_components(
@ -58,15 +58,15 @@ class GNGCallback(Callback):
# Adjust Topology
topology.add_prototype()
topology.add_connection(worst, -1)
topology.add_connection(worst_neighbour, -1)
topology.remove_connection(worst, worst_neighbour)
topology.add_connection(worst_neighbor, -1)
topology.remove_connection(worst, worst_neighbor)
# New errors
worst_error = errors[worst].unsqueeze(0)
pl_module.errors = torch.cat([pl_module.errors, worst_error])
pl_module.errors[worst] = errors[worst] * self.reduction
pl_module.errors[
worst_neighbour] = errors[worst_neighbour] * self.reduction
worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer)
@ -98,7 +98,7 @@ class ConnectionTopology(torch.nn.Module):
self.cmat[i0][self.age[i0] > self.agelimit] = 0
self.cmat[i1][self.age[i1] > self.agelimit] = 0
def get_neighbours(self, position):
def get_neighbors(self, position):
return torch.where(self.cmat[position])
def add_prototype(self):