[REFACTOR] neighbour -> neighbor

This commit is contained in:
Jensun Ravichandran 2021-06-02 00:29:45 +02:00
parent 9eb6476078
commit 91b57b01b1

View File

@ -38,18 +38,18 @@ class GNGCallback(Callback):
# Insertion point # Insertion point
worst = torch.argmax(errors) worst = torch.argmax(errors)
neighbours = topology.get_neighbours(worst)[0] neighbors = topology.get_neighbors(worst)[0]
if len(neighbours) == 0: if len(neighbors) == 0:
logging.log("No Neighbour pair found") logging.log(level=20, msg="No neighbor-pairs found!")
return return
neighbours_errors = errors[neighbours] neighbors_errors = errors[neighbors]
worst_neighbour = neighbours[torch.argmax(neighbours_errors)] worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
# New Prototype # New Prototype
new_component = 0.5 * (components[worst] + new_component = 0.5 * (components[worst] +
components[worst_neighbour]) components[worst_neighbor])
# Add component # Add component
pl_module.proto_layer.add_components( pl_module.proto_layer.add_components(
@ -58,15 +58,15 @@ class GNGCallback(Callback):
# Adjust Topology # Adjust Topology
topology.add_prototype() topology.add_prototype()
topology.add_connection(worst, -1) topology.add_connection(worst, -1)
topology.add_connection(worst_neighbour, -1) topology.add_connection(worst_neighbor, -1)
topology.remove_connection(worst, worst_neighbour) topology.remove_connection(worst, worst_neighbor)
# New errors # New errors
worst_error = errors[worst].unsqueeze(0) worst_error = errors[worst].unsqueeze(0)
pl_module.errors = torch.cat([pl_module.errors, worst_error]) pl_module.errors = torch.cat([pl_module.errors, worst_error])
pl_module.errors[worst] = errors[worst] * self.reduction pl_module.errors[worst] = errors[worst] * self.reduction
pl_module.errors[ pl_module.errors[
worst_neighbour] = errors[worst_neighbour] * self.reduction worst_neighbor] = errors[worst_neighbor] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer) trainer.accelerator_backend.setup_optimizers(trainer)
@ -98,7 +98,7 @@ class ConnectionTopology(torch.nn.Module):
self.cmat[i0][self.age[i0] > self.agelimit] = 0 self.cmat[i0][self.age[i0] > self.agelimit] = 0
self.cmat[i1][self.age[i1] > self.agelimit] = 0 self.cmat[i1][self.age[i1] > self.agelimit] = 0
def get_neighbours(self, position): def get_neighbors(self, position):
return torch.where(self.cmat[position]) return torch.where(self.cmat[position])
def add_prototype(self): def add_prototype(self):