fix crossGenomes to prevent cycles

This commit is contained in:
Michael Peters 2024-09-04 21:00:07 -07:00
parent a5884a68db
commit b33ffc3baa
4 changed files with 34 additions and 73 deletions

View File

@ -58,7 +58,7 @@ export interface TrainerSnapshot {
// labs and mutation ----------------------------------------------------------- // labs and mutation -----------------------------------------------------------
// TODO: random colors for each species // TODO: colors based on species number
function makeLabColors({ hue, sat, lig }: { hue: number; sat: number; lig: number }): LabColors { function makeLabColors({ hue, sat, lig }: { hue: number; sat: number; lig: number }): LabColors {
const head = `hsl(${hue},${sat}%,${lig}%)`; const head = `hsl(${hue},${sat}%,${lig}%)`;
@ -343,5 +343,5 @@ export default function runCanvas(canvas: HTMLCanvasElement, pipeRef: MutableRef
} }
keys.bindKeys(); keys.bindKeys();
// engine.run(update, render); engine.run(update, render);
} }

View File

@ -141,7 +141,7 @@
* - Mating: * - Mating:
* - for each survivor, repeat fertility[i] times: * - for each survivor, repeat fertility[i] times:
* - select a mate (dad): * - select a mate (dad):
* - r_asex chance it is the same organism (TODO: at least one mutation is required?) * - r_asex chance it is the same organism
* - else r_int_sp chance it is from a random other organism from any other species * - else r_int_sp chance it is from a random other organism from any other species
* - else it is from a random other organism from the same species * - else it is from a random other organism from the same species
* - compute new genome (baby) for mom x dad * - compute new genome (baby) for mom x dad
@ -214,18 +214,6 @@
* | 20% | Survival Threshold (your adjusted fitness must be in the top x% to survive to the next generation) * | 20% | Survival Threshold (your adjusted fitness must be in the top x% to survive to the next generation)
* +-------+------- * +-------+-------
* *
* --- TODO --------------------------------------------------------------------
*
* - Determine reproduction algo
* - Better understand Explicit Fitness Sharing
* - Better understand weights mutation parameter uses
*
* - Implement Paper!
*
* --- Implementation Complexities ---------------------------------------------
*
* - Convert from genome -> neural network topology
* - Effectively pass data from inputs, through hidden nodes, to outputs
*/ */
import { edgesToNodes, NodeID, traceParents, RawEdge, traceChildren } from './network'; import { edgesToNodes, NodeID, traceParents, RawEdge, traceChildren } from './network';
@ -453,7 +441,7 @@ export function crossGenomes(mom: Genome, dad: Genome, fitness: Map<Genome, numb
const crossed: Genome = []; const crossed: Genome = [];
// matching // matching - randomly selected weight & random chance to re-enable
for (const { mom: momGene, dad: dadGene } of alignment.matching) { for (const { mom: momGene, dad: dadGene } of alignment.matching) {
const newWeight = Math.random() < 0.5 ? momGene.data.weight : dadGene.data.weight; const newWeight = Math.random() < 0.5 ? momGene.data.weight : dadGene.data.weight;
const wasDisabled = !momGene.data.enabled || !dadGene.data.enabled; const wasDisabled = !momGene.data.enabled || !dadGene.data.enabled;
@ -462,25 +450,15 @@ export function crossGenomes(mom: Genome, dad: Genome, fitness: Map<Genome, numb
crossed.push(newGene); crossed.push(newGene);
} }
// disjoint // disjoint & excess - take from most fit parent
const momFitness = fitness.get(mom)!; const momFitness = fitness.get(mom)!;
const dadFitness = fitness.get(dad)!; const dadFitness = fitness.get(dad)!;
const mostFit = momFitness === dadFitness ? 'equal' : momFitness > dadFitness ? 'mom' : 'dad'; if (momFitness > dadFitness) {
for (const { mom: momGene, dad: dadGene } of alignment.disjoint) { crossed.push(...alignment.disjoint.flatMap(({ mom: momGene }) => (momGene === null ? [] : momGene)));
if (momGene === null) crossed.push(dadGene as Gene); crossed.push(...alignment.excess.flatMap(({ mom: momGene }) => (momGene === null ? [] : momGene)));
else if (dadGene === null) crossed.push(momGene); } else {
else if (mostFit === 'mom') crossed.push(momGene); crossed.push(...alignment.disjoint.flatMap(({ dad: dadGene }) => (dadGene === null ? [] : dadGene)));
else if (mostFit === 'dad') crossed.push(dadGene); crossed.push(...alignment.excess.flatMap(({ dad: dadGene }) => (dadGene === null ? [] : dadGene)));
// both are equally fit - select at random
else if (Math.random() < 0.5) crossed.push(momGene);
else crossed.push(dadGene);
}
// excess
for (const { mom: momGene, dad: dadGene } of alignment.excess) {
if (momGene === null) crossed.push(dadGene as Gene);
else if (dadGene === null) crossed.push(momGene as Gene);
else throw Error(`invalid excess alignment: alignment=${JSON.stringify(alignment)}`);
} }
return crossed; return crossed;
@ -535,7 +513,6 @@ export function mutateNewConn(newConn: { src_id: NodeID; dst_id: NodeID }, newWe
}; };
} }
// TODO: improve name since also excluding source nodes
export function findAcyclicInternalNewConns<DataT>(rawEdges: RawEdge<DataT>[]): { src_id: NodeID; dst_id: NodeID }[] { export function findAcyclicInternalNewConns<DataT>(rawEdges: RawEdge<DataT>[]): { src_id: NodeID; dst_id: NodeID }[] {
// finds potential new connections that are acyclic and are not already connected // finds potential new connections that are acyclic and are not already connected
@ -548,16 +525,19 @@ export function findAcyclicInternalNewConns<DataT>(rawEdges: RawEdge<DataT>[]):
const parents = traceParents(nodes); const parents = traceParents(nodes);
const children = traceChildren(nodes); const children = traceChildren(nodes);
// exclude *both* sources and sinks // exclude sources from being dst nodes *and* exclude sinks from being src nodes
// - sources are defined during think, connections will be ignored
// - sinks should not be connected? <-- they probably could be, but it would be unclean imo
const sources = new Set(Array.from(parents.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : []))); const sources = new Set(Array.from(parents.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : [])));
const sinks = new Set(Array.from(children.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : []))); const sinks = new Set(Array.from(children.entries()).flatMap(([k, v]) => (v.size === 0 ? [k] : [])));
const acyclic = new Map<NodeID, Set<NodeID>>(); const acyclic = new Map<NodeID, Set<NodeID>>();
for (const [nodeID, nodeParents] of parents.entries()) { for (const [nodeID, nodeParents] of parents.entries()) {
if (sinks.has(nodeID)) {
// exclude sinks from being src nodes
acyclic.set(nodeID, new Set());
continue;
}
const nodeParentIDs = setMap(nodeParents, n => n.id); const nodeParentIDs = setMap(nodeParents, n => n.id);
const exclude = setUnion(nodeParentIDs, sources, sinks, new Set([nodeID])); const exclude = setUnion(nodeParentIDs, sources, new Set([nodeID]));
const nodeIDsAcyclic = setDifference(allNodeIDs, exclude); const nodeIDsAcyclic = setDifference(allNodeIDs, exclude);
acyclic.set(nodeID, nodeIDsAcyclic); acyclic.set(nodeID, nodeIDsAcyclic);
} }
@ -588,7 +568,6 @@ export interface MutateConfig {
export function mutate(genome: Genome, config: MutateConfig): Genome { export function mutate(genome: Genome, config: MutateConfig): Genome {
const { mutate_rate, assign_rate, assign_mag, perturb_mag, new_node_rate, new_connection_rate } = config; const { mutate_rate, assign_rate, assign_mag, perturb_mag, new_node_rate, new_connection_rate } = config;
console.log('mutating: ', hashGenome(genome), genome);
findAcyclicInternalNewConns(genome); findAcyclicInternalNewConns(genome);
const newGenome = genome.map(gene => { const newGenome = genome.map(gene => {
@ -613,8 +592,6 @@ export function mutate(genome: Genome, config: MutateConfig): Genome {
} }
if (Math.random() < new_connection_rate) { if (Math.random() < new_connection_rate) {
// TODO: figure out why we're not hitting a
// TODO: disallow "source" nodes from being new connection destinations
// create a new connection between two *previously unconnected* nodes // create a new connection between two *previously unconnected* nodes
const options = findAcyclicInternalNewConns(newGenome); const options = findAcyclicInternalNewConns(newGenome);
if (options.length === 0) { if (options.length === 0) {
@ -624,17 +601,12 @@ export function mutate(genome: Genome, config: MutateConfig): Genome {
} else { } else {
// choose a random connection // choose a random connection
const newConn = randchoice(options); const newConn = randchoice(options);
console.log('adding new connection...', hashGenome(genome), genome, newConn);
if (new Set(['U', 'D', 'L', 'R']).has(newConn.dst_id)) {
debugger;
}
const newGene = mutateNewConn(newConn, assign_mag * randomNegPos()); const newGene = mutateNewConn(newConn, assign_mag * randomNegPos());
newGenome.push(newGene); newGenome.push(newGene);
} }
findAcyclicInternalNewConns(newGenome); findAcyclicInternalNewConns(newGenome);
} }
console.log('created: ', hashGenome(newGenome), newGenome);
return newGenome; return newGenome;
} }
@ -667,7 +639,6 @@ export function computeNextGeneration(
const mom = randchoice(winners); const mom = randchoice(winners);
const dad = chooseMate(mom, winnersPopulation, mcc); const dad = chooseMate(mom, winnersPopulation, mcc);
const crossed = crossGenomes(mom, dad, fitness, cc); const crossed = crossGenomes(mom, dad, fitness, cc);
// TODO: crossed has cycles!
const baby = mutate(crossed, mc); const baby = mutate(crossed, mc);
// assign to a species + add to next generation // assign to a species + add to next generation

View File

@ -78,9 +78,8 @@ export function traceParents<DataT>(nodes: Network<Node<DataT>>) {
const parents = new Map<NodeID, Set<Node<DataT>>>(); const parents = new Map<NodeID, Set<Node<DataT>>>();
function traceNodeParents(node: Node<DataT>, depth: number): Set<Node<DataT>> { function traceNodeParents(node: Node<DataT>, depth: number): Set<Node<DataT>> {
console.log('tracing: ', node.id, depth);
// TODO: there's a particularly nasty issue where maybe nodes are not just from their own genome?
if (depth > 20) { if (depth > 20) {
// TODO: remove this check, it mostly just finds cycles
debugger; debugger;
} }
if (parents.has(node.id)) return parents.get(node.id)!; if (parents.has(node.id)) return parents.get(node.id)!;

View File

@ -267,14 +267,14 @@ function testCrossGenomes() {
]; ];
const fitness = new Map<Genome, number>(); const fitness = new Map<Genome, number>();
fitness.set(genomeA, 2); fitness.set(genomeA, 1);
fitness.set(genomeB, 1); fitness.set(genomeB, 2);
const cc = { reenable_rate: 1 }; const cc = { reenable_rate: 1 };
const crossed = crossGenomes(genomeA, genomeB, fitness, cc); const crossed = crossGenomes(genomeA, genomeB, fitness, cc);
assert(crossed.length === 10); assert(crossed.length === 9);
// matching // matching
assert(crossed[0]!.src_id === 'A'); assert(crossed[0]!.src_id === 'A');
@ -297,22 +297,13 @@ function testCrossGenomes() {
assert(crossed[4]!.data.weight === 5 || crossed[4]!.data.weight === 3); assert(crossed[4]!.data.weight === 5 || crossed[4]!.data.weight === 3);
assert(crossed[4]!.data.enabled === true); // always re-enabled since reenable_rate = 1 assert(crossed[4]!.data.enabled === true); // always re-enabled since reenable_rate = 1
// disjoint // disjoint (all from genomeB)
assert(crossed[5]!.data.innovation === 6); assert(crossed[5]!.data.innovation === 6);
assert(crossed[5]!.data.weight === 5);
assert(crossed[6]!.data.innovation === 7); assert(crossed[6]!.data.innovation === 7);
assert(crossed[6]!.data.weight === 5);
assert(crossed[7]!.data.innovation === 8); // excess (all from genomeB)
assert(crossed[7]!.data.weight === 7); assert(crossed[7]!.data.innovation === 9);
assert(crossed[8]!.data.innovation === 10);
// excess
assert(crossed[8]!.data.innovation === 9);
assert(crossed[8]!.data.weight === 5);
assert(crossed[9]!.data.innovation === 10);
assert(crossed[9]!.data.weight === 5);
} }
addTest(testCrossGenomes); addTest(testCrossGenomes);
@ -380,21 +371,21 @@ function testFindAcyclicInternalNewConns() {
const options = findAcyclicInternalNewConns(edges); const options = findAcyclicInternalNewConns(edges);
const expected = [ const expected = [
// not A -> B - B is a source // not A -> B - B is a source
// not A -> F - F is a sink { src_id: 'A', dst_id: 'F' },
// not B -> A - A is a source // not B -> A - A is a source
{ src_id: 'B', dst_id: 'C' }, { src_id: 'B', dst_id: 'C' },
// not B -> E - E is a sink { src_id: 'B', dst_id: 'E' },
// not B -> F - F is a sink { src_id: 'B', dst_id: 'F' },
// not C -> B - B is a source // not C -> B - B is a source
{ src_id: 'C', dst_id: 'D' }, { src_id: 'C', dst_id: 'D' },
// not C -> F - F is a sink { src_id: 'C', dst_id: 'F' },
{ src_id: 'D', dst_id: 'C' }, { src_id: 'D', dst_id: 'C' },
// not E -> B - B is a parent of E // not E -> B - B is a parent of E
// not E -> F - F is a sink // not E -> F - F is a sink
// not F -> A - A is a parent of F // not F -> A - F is a sink
// not F -> B - B is a parent of F // not F -> B - F is a sink
{ src_id: 'F', dst_id: 'C' }, // not F -> C - F is a sink
// not F -> E - E is a sink // not F -> E - F is a sink
]; ];
function strcmp(a: string, b: string) { function strcmp(a: string, b: string) {