diff --git a/src/site/snake/neat.ts b/src/site/snake/neat.ts index e3b37a2..70391b2 100644 --- a/src/site/snake/neat.ts +++ b/src/site/snake/neat.ts @@ -229,7 +229,7 @@ */ import { edgesToNodes, NodeID, traceParents, RawEdge, traceChildren } from './network'; -import { keyMax, mapInvert, mapMap, randchoice, randint, randomNegPos, setDifference, setMap } from './util'; +import { keyMax, mapInvert, mapMap, randchoice, randint, randomNegPos, setDifference, setMap, setUnion } from './util'; export interface GeneData { innovation: number; @@ -548,19 +548,18 @@ export function findAcyclicInternalNewConns(rawEdges: RawEdge[]): const parents = traceParents(nodes); const children = traceChildren(nodes); - // TODO: exclude *both* sources and sinks - // - sources are defined during think - // - sinks should not be connected? <-- they probably could be, but not clean imo + // exclude *both* sources and sinks + // - sources are defined during think, connections will be ignored + // - sinks should not be connected? <-- they probably could be, but it would be unclean imo const sources = new Set(Array.from(parents.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : []))); const sinks = new Set(Array.from(children.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : []))); const acyclic = new Map>(); for (const [nodeID, nodeParents] of parents.entries()) { const nodeParentIDs = setMap(nodeParents, n => n.id); - const nodeIDsAcyclic = setDifference(allNodeIDs, nodeParentIDs); - nodeIDsAcyclic.delete(nodeID); - const nodeIDsAcyclicNonSource = setDifference(nodeIDsAcyclic, sources); - acyclic.set(nodeID, nodeIDsAcyclicNonSource); + const exclude = setUnion(nodeParentIDs, sources, sinks, new Set([nodeID])); + const nodeIDsAcyclic = setDifference(allNodeIDs, exclude); + acyclic.set(nodeID, nodeIDsAcyclic); } // flatten diff --git a/src/site/snake/util.ts b/src/site/snake/util.ts index 75d83d3..506dd2c 100644 --- a/src/site/snake/util.ts +++ b/src/site/snake/util.ts @@ -46,7 +46,7 @@ export function mapMap(m: Map, mapper: (k: K1, v: V1) => return new Map(Array.from(m.entries()).map(([k, v]) => mapper(k, v))); } -export function setIntersection(...sets: Set[]) { +export function setIntersection(...sets: Set[]): Set { // a & b & ... if (sets.length === 0) return new Set(); const intersection = new Set(sets[0]!); @@ -56,7 +56,7 @@ export function setIntersection(...sets: Set[]) { return intersection; } -export function setUnion(...sets: Set[]) { +export function setUnion(...sets: Set[]): Set { // a & b & ... if (sets.length === 0) return new Set(); const union = new Set(sets[0]!); diff --git a/src/test/test-neat.ts b/src/test/test-neat.ts index 2e2d4bc..bedc89f 100644 --- a/src/test/test-neat.ts +++ b/src/test/test-neat.ts @@ -380,21 +380,21 @@ function testFindAcyclicInternalNewConns() { const options = findAcyclicInternalNewConns(edges); const expected = [ // not A -> B - B is a source - { src_id: 'A', dst_id: 'F' }, + // not A -> F - F is a sink // not B -> A - A is a source { src_id: 'B', dst_id: 'C' }, - { src_id: 'B', dst_id: 'E' }, - { src_id: 'B', dst_id: 'F' }, + // not B -> E - E is a sink + // not B -> F - F is a sink // not C -> B - B is a source { src_id: 'C', dst_id: 'D' }, - { src_id: 'C', dst_id: 'F' }, + // not C -> F - F is a sink { src_id: 'D', dst_id: 'C' }, // not E -> B - B is a parent of E - { src_id: 'E', dst_id: 'F' }, + // not E -> F - F is a sink // not F -> A - A is a parent of F // not F -> B - B is a parent of F { src_id: 'F', dst_id: 'C' }, - { src_id: 'F', dst_id: 'E' }, + // not F -> E - E is a sink ]; function strcmp(a: string, b: string) {