From 889b0dfa779b9a605d3b40b88f38244b78d7d128 Mon Sep 17 00:00:00 2001 From: Michael Peters Date: Mon, 26 Aug 2024 14:09:48 -0700 Subject: [PATCH] create new connection script --- src/site/snake/brain-neat.ts | 70 +++++++++++++++++++++++++---------- src/site/snake/game-engine.ts | 1 + src/site/snake/util.ts | 4 ++ 3 files changed, 56 insertions(+), 19 deletions(-) diff --git a/src/site/snake/brain-neat.ts b/src/site/snake/brain-neat.ts index acfd477..544bfba 100644 --- a/src/site/snake/brain-neat.ts +++ b/src/site/snake/brain-neat.ts @@ -119,7 +119,7 @@ * - use random values connection weights * - all nodes should be marked "enabled" * - all organisms in the initial population will be part of species #1 - * - this implies that randomized connection weights should be chosen such that c3*W < δ_t for all organisms + * - this implies that randomized connection weights should be chosen such that c3*W < δ_t for all organisms * 3. Training Loop * a) Compute fitness f_i for all organisms[i] * b) Find adjusted f_adj_i based on each organism's species @@ -129,7 +129,7 @@ * - All organisms whose fitness passes the survival threshold (based on f_adj_i) are considered survivors * - Add all survivors to the next generation * - Fertility: fertility[i] = # of offspring from organism i as the mom - * - Organisms from species that have not improved in fitness for stag_lim generations are given fertility[i] = 0 (barred from mating) + * - Organisms from species that have not improved in fitness for stag_lim generations are given fertility[i] = 0 (barred from mating) * - Count the number of fertile survivors * - Compute general, remainder = divmod((pop - # survivors), # fertile survivors) * - Select remainder organisms to be given one bonus fertility (at random or by fitness) @@ -228,8 +228,8 @@ * - Effectively pass data from inputs, through hidden nodes, to outputs */ -import { edgesToNodes, NodeID, Network, Node, topoSort } from './network'; -import { keyMax, randint, randomNegPos } from './util'; +import { edgesToNodes, NodeID, Network, Node, topoSort, traceParents, RawEdge } from './network'; +import { keyMax, randchoice, randint, randomNegPos, setDifference, setMap } from './util'; interface GeneData { innovation: number; @@ -375,15 +375,6 @@ export function chooseSurvivors(population: Population, fitness: Map { - const nodeIDs = new Set(); - for (const gene of genome) { - nodeIDs.add(gene.src_id); - nodeIDs.add(gene.dst_id); - } - return nodeIDs; -} - interface MutateConfig { mutate_rate: number; // chance to mutate a gene's weight assign_rate: number; // chance to assign instead of uniformly perturb @@ -440,13 +431,53 @@ export function mutate(genome: Genome, config: MutateConfig): Genome { } if (Math.random() < new_connection_rate) { - // create a new connection - // between *previously unconnected* nodes - // TODO: use traceParents - // TODO: test traceParents + // create a new connection between two *previously unconnected* nodes + // NOTE: there's some performance stuff here that could definitely be improved + const nodes = edgesToNodes(newGenome); - const sources = Array.from(nodes.values()).filter(n => n.srcs.size === 0); - const sinks = Array.from(nodes.values()).filter(n => n.dsts.size === 0); + + // find nodes that can be connected without creating a cycle + // a node that is connected to one of its parents creates a cycle + const allNodeIDs = new Set(nodes.keys()); + const parents = traceParents(nodes); + const acyclic = new Map>(); + for (const [nodeID, nodeParents] of parents.entries()) { + const nodeParentIDs = setMap(nodeParents, n => n.id); + const nodeAcyclic = setDifference(allNodeIDs, nodeParentIDs); + acyclic.set(nodeID, nodeAcyclic); + } + + // flatten + const acyclicConns: { src_id: NodeID, dst_id: NodeID }[] = []; + for (const [nodeID, nodeAcyclic] of acyclic) { + acyclicConns.push(...setMap(nodeAcyclic, dst => ({ src_id: nodeID, dst_id: dst }))); + } + + // remove options that are already connected + const options: { src_id: NodeID, dst_id: NodeID }[] = []; + for (const conn of acyclicConns) { + if (newGenome.findIndex(c => c.src_id === conn.src_id && c.dst_id === conn.dst_id) === -1) continue; + options.push(conn); + } + + // choose a random connection + if (options.length === 0) { + // TODO: remove this warn once this starts working + // this is mostly a sanity check / useful for metrics + console.warn("could not find a valid new connection!"); + } else { + const { src_id, dst_id } = randchoice(options); + const newGene = { + src_id, + dst_id, + data: { + innovation: g_innovation_number++, + weight: assign_mag * randomNegPos(), + enabled: true, + } + } + newGenome.push(newGene); + } } return newGenome; @@ -476,6 +507,7 @@ export class Organism { /** given per-node input activations, computes downstream activations in-place.*/ think(activations: Map) { + // TODO: do not follow disabled connections for (const id of this.order) { const node = this.network.get(id)!; if (node.srcs.size === 0) { diff --git a/src/site/snake/game-engine.ts b/src/site/snake/game-engine.ts index 29d8185..12eaf62 100644 --- a/src/site/snake/game-engine.ts +++ b/src/site/snake/game-engine.ts @@ -7,6 +7,7 @@ export function randint(low: number, high: number) { return Math.floor(Math.random() * range) + low; } +/** deprecated, use util.ts instead */ export function randchoice(arr: T[]) { return arr[randint(0, arr.length - 1)]!; } diff --git a/src/site/snake/util.ts b/src/site/snake/util.ts index 64ccc63..57b6d15 100644 --- a/src/site/snake/util.ts +++ b/src/site/snake/util.ts @@ -9,6 +9,10 @@ export function randint(low: number, high: number) { return Math.floor(Math.random() * range) + low; } +export function randchoice(arr: T[]) { + return arr[randint(0, arr.length - 1)]!; +} + export function keyMax(values: Iterable, keyFunc: (v: T) => number) { let best = null; for (const value of values) {