[REFACTOR] neighbour
-> neighbor
This commit is contained in:
parent
9eb6476078
commit
91b57b01b1
@ -38,18 +38,18 @@ class GNGCallback(Callback):
|
|||||||
# Insertion point
|
# Insertion point
|
||||||
worst = torch.argmax(errors)
|
worst = torch.argmax(errors)
|
||||||
|
|
||||||
neighbours = topology.get_neighbours(worst)[0]
|
neighbors = topology.get_neighbors(worst)[0]
|
||||||
|
|
||||||
if len(neighbours) == 0:
|
if len(neighbors) == 0:
|
||||||
logging.log("No Neighbour pair found")
|
logging.log(level=20, msg="No neighbor-pairs found!")
|
||||||
return
|
return
|
||||||
|
|
||||||
neighbours_errors = errors[neighbours]
|
neighbors_errors = errors[neighbors]
|
||||||
worst_neighbour = neighbours[torch.argmax(neighbours_errors)]
|
worst_neighbor = neighbors[torch.argmax(neighbors_errors)]
|
||||||
|
|
||||||
# New Prototype
|
# New Prototype
|
||||||
new_component = 0.5 * (components[worst] +
|
new_component = 0.5 * (components[worst] +
|
||||||
components[worst_neighbour])
|
components[worst_neighbor])
|
||||||
|
|
||||||
# Add component
|
# Add component
|
||||||
pl_module.proto_layer.add_components(
|
pl_module.proto_layer.add_components(
|
||||||
@ -58,15 +58,15 @@ class GNGCallback(Callback):
|
|||||||
# Adjust Topology
|
# Adjust Topology
|
||||||
topology.add_prototype()
|
topology.add_prototype()
|
||||||
topology.add_connection(worst, -1)
|
topology.add_connection(worst, -1)
|
||||||
topology.add_connection(worst_neighbour, -1)
|
topology.add_connection(worst_neighbor, -1)
|
||||||
topology.remove_connection(worst, worst_neighbour)
|
topology.remove_connection(worst, worst_neighbor)
|
||||||
|
|
||||||
# New errors
|
# New errors
|
||||||
worst_error = errors[worst].unsqueeze(0)
|
worst_error = errors[worst].unsqueeze(0)
|
||||||
pl_module.errors = torch.cat([pl_module.errors, worst_error])
|
pl_module.errors = torch.cat([pl_module.errors, worst_error])
|
||||||
pl_module.errors[worst] = errors[worst] * self.reduction
|
pl_module.errors[worst] = errors[worst] * self.reduction
|
||||||
pl_module.errors[
|
pl_module.errors[
|
||||||
worst_neighbour] = errors[worst_neighbour] * self.reduction
|
worst_neighbor] = errors[worst_neighbor] * self.reduction
|
||||||
|
|
||||||
trainer.accelerator_backend.setup_optimizers(trainer)
|
trainer.accelerator_backend.setup_optimizers(trainer)
|
||||||
|
|
||||||
@ -98,7 +98,7 @@ class ConnectionTopology(torch.nn.Module):
|
|||||||
self.cmat[i0][self.age[i0] > self.agelimit] = 0
|
self.cmat[i0][self.age[i0] > self.agelimit] = 0
|
||||||
self.cmat[i1][self.age[i1] > self.agelimit] = 0
|
self.cmat[i1][self.age[i1] > self.agelimit] = 0
|
||||||
|
|
||||||
def get_neighbours(self, position):
|
def get_neighbors(self, position):
|
||||||
return torch.where(self.cmat[position])
|
return torch.where(self.cmat[position])
|
||||||
|
|
||||||
def add_prototype(self):
|
def add_prototype(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user