From a0a2e7c9bbfed9499b580f651941d286225092aa Mon Sep 17 00:00:00 2001 From: Michael Peters Date: Wed, 4 Sep 2024 20:27:50 -0700 Subject: [PATCH] add setUnion and make setIntersection more versatile --- src/site/snake/canvas.ts | 2 +- src/site/snake/neat.ts | 15 +++++++------- src/site/snake/util.ts | 43 +++++++++++++++++++++++++--------------- src/test/test-neat.ts | 8 ++++---- src/test/test-util.ts | 20 ++++++++++++++++++- 5 files changed, 58 insertions(+), 30 deletions(-) diff --git a/src/site/snake/canvas.ts b/src/site/snake/canvas.ts index 29707e7..4726df9 100644 --- a/src/site/snake/canvas.ts +++ b/src/site/snake/canvas.ts @@ -343,5 +343,5 @@ export default function runCanvas(canvas: HTMLCanvasElement, pipeRef: MutableRef } keys.bindKeys(); - engine.run(update, render); + // engine.run(update, render); } diff --git a/src/site/snake/neat.ts b/src/site/snake/neat.ts index c4e7751..e3b37a2 100644 --- a/src/site/snake/neat.ts +++ b/src/site/snake/neat.ts @@ -536,7 +536,7 @@ export function mutateNewConn(newConn: { src_id: NodeID; dst_id: NodeID }, newWe } // TODO: improve name since also excluding source nodes -export function findAcyclicNonSourceNewConns(rawEdges: RawEdge[]): { src_id: NodeID; dst_id: NodeID }[] { +export function findAcyclicInternalNewConns(rawEdges: RawEdge[]): { src_id: NodeID; dst_id: NodeID }[] { // finds potential new connections that are acyclic and are not already connected // NOTE: there's some performance stuff here that could definitely be improved @@ -548,12 +548,11 @@ export function findAcyclicNonSourceNewConns(rawEdges: RawEdge[]): const parents = traceParents(nodes); const children = traceChildren(nodes); - // TODO: think on this more: - // adding a new connection needs to not connect a source -> source or a sink -> sink - // or else the innovation number system gets messed up + // TODO: exclude *both* sources and sinks + // - sources are defined during think + // - sinks should not be connected? <-- they probably could be, but not clean imo const sources = new Set(Array.from(parents.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : []))); const sinks = new Set(Array.from(children.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : []))); - console.log({ sources, sinks }); const acyclic = new Map>(); for (const [nodeID, nodeParents] of parents.entries()) { @@ -591,7 +590,7 @@ export function mutate(genome: Genome, config: MutateConfig): Genome { const { mutate_rate, assign_rate, assign_mag, perturb_mag, new_node_rate, new_connection_rate } = config; console.log('mutating: ', hashGenome(genome), genome); - findAcyclicNonSourceNewConns(genome); + findAcyclicInternalNewConns(genome); const newGenome = genome.map(gene => { if (Math.random() >= mutate_rate) return gene; // this connection should not be mutated @@ -618,7 +617,7 @@ export function mutate(genome: Genome, config: MutateConfig): Genome { // TODO: figure out why we're not hitting a // TODO: disallow "source" nodes from being new connection destinations // create a new connection between two *previously unconnected* nodes - const options = findAcyclicNonSourceNewConns(newGenome); + const options = findAcyclicInternalNewConns(newGenome); if (options.length === 0) { // TODO: remove this warn once this starts working // this is mostly a sanity check / useful for metrics @@ -634,7 +633,7 @@ export function mutate(genome: Genome, config: MutateConfig): Genome { newGenome.push(newGene); } - findAcyclicNonSourceNewConns(newGenome); + findAcyclicInternalNewConns(newGenome); } console.log('created: ', hashGenome(newGenome), newGenome); diff --git a/src/site/snake/util.ts b/src/site/snake/util.ts index 2bb03dd..75d83d3 100644 --- a/src/site/snake/util.ts +++ b/src/site/snake/util.ts @@ -27,34 +27,45 @@ export function keyMax(values: Iterable, keyFunc: (v: T) => number) { return best.value; } -// TODO: add test for mapInvert export function mapInvert(m: Map): Map> { - const newM = new Map>(); - for (const [k, v] of m.entries()) { - if (newM.has(v)) { - newM.get(v)!.add(k); - } else { - newM.set(v, new Set([k])); - } - } - return newM; + const newM = new Map>(); + for (const [k, v] of m.entries()) { + if (newM.has(v)) { + newM.get(v)!.add(k); + } else { + newM.set(v, new Set([k])); + } + } + return newM; } export function setMap(s: Set, mapper: (t: T) => U): Set { return new Set(Array.from(s).map(mapper)); } -// TODO: add test for mapMap export function mapMap(m: Map, mapper: (k: K1, v: V1) => [k: K2, v: V2]): Map { - return new Map(Array.from(m.entries()).map(([k, v]) => mapper(k, v))); + return new Map(Array.from(m.entries()).map(([k, v]) => mapper(k, v))); } -export function setIntersection(a: Set, b: Set) { - // a - b - const intersection = new Set(a); - for (const e of a) if (!b.has(e)) intersection.delete(e); +export function setIntersection(...sets: Set[]) { + // a & b & ... + if (sets.length === 0) return new Set(); + const intersection = new Set(sets[0]!); + for (const set of sets.slice(1)) { + for (const e of intersection) if (!set.has(e)) intersection.delete(e); + } return intersection; } +export function setUnion(...sets: Set[]) { + // a & b & ... + if (sets.length === 0) return new Set(); + const union = new Set(sets[0]!); + for (const set of sets.slice(1)) { + for (const e of set) union.add(e); + } + return union; +} + export function setDifference(a: Set, b: Set) { // a - b const diff = new Set(a); diff --git a/src/test/test-neat.ts b/src/test/test-neat.ts index e3d4dae..2e2d4bc 100644 --- a/src/test/test-neat.ts +++ b/src/test/test-neat.ts @@ -4,7 +4,7 @@ import { tournamentSelectionWithChampions, compatibilityDistance, crossGenomes, - findAcyclicNonSourceNewConns, + findAcyclicInternalNewConns, Genome, mutateAssign, mutateNewConn, @@ -357,7 +357,7 @@ function testMutateNewConn() { } addTest(testMutateNewConn); -function testFindAcyclicNonSourceNewConns() { +function testFindAcyclicInternalNewConns() { /* * all edges pointing down * @@ -377,7 +377,7 @@ function testFindAcyclicNonSourceNewConns() { { src_id: 'D', dst_id: 'F', data: null }, ]; - const options = findAcyclicNonSourceNewConns(edges); + const options = findAcyclicInternalNewConns(edges); const expected = [ // not A -> B - B is a source { src_id: 'A', dst_id: 'F' }, @@ -406,7 +406,7 @@ function testFindAcyclicNonSourceNewConns() { assertDeepEqual(options, expected); } -addTest(testFindAcyclicNonSourceNewConns); +addTest(testFindAcyclicInternalNewConns); function testComputeNextGeneration() { function makeGenome(weight: number) { diff --git a/src/test/test-util.ts b/src/test/test-util.ts index 0b3bc3e..c33e54d 100644 --- a/src/test/test-util.ts +++ b/src/test/test-util.ts @@ -1,4 +1,4 @@ -import { keyMax, mapInvert, mapMap, setDifference, setMap } from '../site/snake/util'; +import { keyMax, mapInvert, mapMap, setDifference, setIntersection, setMap, setUnion } from '../site/snake/util'; import { addTest, assert, assertSetEqual } from './tests'; function testKeyMax() { @@ -34,6 +34,24 @@ function testSetMap() { } addTest(testSetMap); +function testSetIntersection() { + const a = new Set([1, 2, 3]); + const b = new Set([2, 3, 4]); + const expected = new Set([2, 3]); + const actual = setIntersection(a, b); + assertSetEqual(actual, expected); +} +addTest(testSetIntersection); + +function testSetUnion() { + const a = new Set([1, 2, 3]); + const b = new Set([2, 3, 4]); + const expected = new Set([1, 2, 3, 4]); + const actual = setUnion(a, b); + assertSetEqual(actual, expected); +} +addTest(testSetUnion); + function testMapMap() { const src = new Map(); src.set('A', 1);