get started with actually running the simulation using the NEAT brain

This commit is contained in:
Michael Peters 2024-09-03 20:48:38 -07:00
parent c631383087
commit 9e624f225e
4 changed files with 56 additions and 27 deletions

View File

@ -1,19 +1,20 @@
import { MutableRefObject } from 'react';
import { MutationConfig } from './brain';
import { clamp, Engine, Keys, randchoice, randint, UI, Vec2, vec2 } from './game-engine';
import { SnakeBrain } from './snake-brain';
import { Engine, Keys, randint, UI, Vec2, vec2 } from './game-engine';
import { SGSHashSet, Snake, SnakeGameState, SnakeGameStateWithHistory } from './types';
import { PipeRef } from '.';
import {
assignSpecies,
CompatibilityDistanceConfig,
CompatibilityDistanceThreshold,
computeNextGeneration,
CrossConfig,
FertilityConfig,
Genome,
MateChoiceConfig,
mutate,
MutateConfig,
NextGenerationConfig,
resetGlobalIDs,
} from './neat';
import { BASE_GENOME_SNAKE_BRAIN_NEAT, NEATSnakeBrain } from './neat-snake-brain';
@ -36,8 +37,8 @@ interface LabColors {
export interface SnakeGameTrainerLab {
id: number;
colors: LabColors;
brain: SnakeBrain;
state: SnakeGameStateWithHistory;
brain: NEATSnakeBrain;
}
interface SnakeGameTrainerState {
@ -159,16 +160,27 @@ export default function runCanvas(canvas: HTMLCanvasElement, pipeRef: MutableRef
c3: 1,
};
const CDT: CompatibilityDistanceThreshold = 1.0;
const NGC: NextGenerationConfig = {
fc: FC,
mcc: MCC,
cc: CC,
mc: MC,
cdc: CDC,
cdt: CDT,
};
// general simulation ------------------------------------------------------
const generation = 1;
let generation = 1;
// labs & initial population -----------------------------------------------
const initialGenomes = new Array(SNAKES).map(() => mutate(BASE_GENOME_SNAKE_BRAIN_NEAT, MC));
resetGlobalIDs({ node_id: 1, innovation_number: 1, species_id: 1 });
// eslint-disable-next-line prefer-spread
const initialGenomes = Array.apply(null, Array(SNAKES)).map(() => mutate(BASE_GENOME_SNAKE_BRAIN_NEAT, MC));
// assign initial species
const population = new Map();
const reps = new Map();
let population = new Map();
let reps = new Map();
for (const genome of initialGenomes) {
const sid = assignSpecies(genome, reps, CDC, CDT);
population.set(genome, sid);
@ -176,8 +188,8 @@ export default function runCanvas(canvas: HTMLCanvasElement, pipeRef: MutableRef
let nextLabId = 1;
const trainer: SnakeGameTrainerState = {
labs: Array.from({ length: SNAKES }).map(_ => makeRandomLab({ id: nextLabId++, hiddenLayerNodes: 6 })),
let trainer: SnakeGameTrainerState = {
labs: initialGenomes.map(g => makeLab(nextLabId++, g)),
};
function update() {
@ -191,11 +203,21 @@ export default function runCanvas(canvas: HTMLCanvasElement, pipeRef: MutableRef
engine.setUpdateDelay(0);
}
// TODO: compute next generation when all snakes are dead
// cull weak when all snakes are dead
const allDead = trainer.labs.findIndex(l => !l.state.dead) === -1;
if (allDead) {
// cullWeakFearStrong();
console.log('computing next gen');
// compute next generation
const fitness = new Map<Genome, number>();
for (const lab of trainer.labs) {
fitness.set(lab.brain.brain.genome, lab.state.snake.length);
}
const { nextPopulation, nextReps } = computeNextGeneration(population, fitness, NGC);
population = nextPopulation;
reps = nextReps;
trainer = {
labs: Array.from(population.keys()).map(g => makeLab(nextLabId++, g)),
};
generation++;
}
for (const lab of trainer.labs) {
@ -321,5 +343,5 @@ export default function runCanvas(canvas: HTMLCanvasElement, pipeRef: MutableRef
}
keys.bindKeys();
// engine.run(update, render);
engine.run(update, render);
}

View File

@ -352,6 +352,8 @@ export function tournamentSelectionWithChampions(
// compute adjusted fitness by scaling fitness by species size
const adjFitness = mapMap(fitness, (k, v) => {
const sid = population.get(k)!;
const spec = species.get(sid)!;
if (!spec) debugger;
const speciesSize = species.get(sid)!.size;
return [k, v / speciesSize];
});
@ -551,7 +553,7 @@ export function mutate(genome: Genome, config: MutateConfig): Genome {
const { mutate_rate, assign_rate, assign_mag, perturb_mag, new_node_rate, new_connection_rate } = config;
const newGenome = genome.map(gene => {
if (Math.random() < mutate_rate) return gene; // this connection should not be mutated
if (Math.random() >= mutate_rate) return gene; // this connection should not be mutated
if (Math.random() < assign_rate) {
return mutateAssign(gene, assign_mag * randomNegPos());
} else {
@ -609,10 +611,10 @@ export function computeNextGeneration(
const winnersPopulation = new Map(winners.map(w => [w, population.get(w)!]));
// copy over champions to the next generation and use them as representatives for their species
const nextGeneration = mapMap(champions, (sid, c) => [c, sid]);
const reps = new Map(champions);
const nextPopulation = mapMap(champions, (sid, c) => [c, sid]);
const nextReps = new Map(champions);
while (nextGeneration.size < population.size) {
while (nextPopulation.size < population.size) {
// mate
const mom = randchoice(winners);
const dad = chooseMate(mom, winnersPopulation, mcc);
@ -620,9 +622,9 @@ export function computeNextGeneration(
const baby = mutate(crossed, mc);
// assign to a species + add to next generation
const sid = assignSpecies(baby, reps, cdc, cdt);
nextGeneration.set(baby, sid);
const sid = assignSpecies(baby, nextReps, cdc, cdt);
nextPopulation.set(baby, sid);
}
return nextGeneration;
return { nextPopulation, nextReps };
}

View File

@ -77,13 +77,18 @@ export function traceParents<DataT>(nodes: Network<Node<DataT>>) {
// this function is O(Nodes+Edges)
const parents = new Map<NodeID, Set<Node<DataT>>>();
function traceNodeParents(node: Node<DataT>): Set<Node<DataT>> {
function traceNodeParents(node: Node<DataT>, depth: number): Set<Node<DataT>> {
console.log('tracing: ', node.id, depth);
// TODO: there's a particularly nasty issue where maybe nodes are not just from their own genome?
if (depth > 1000) {
debugger;
}
if (parents.has(node.id)) return parents.get(node.id)!;
const nodeParents = new Set<Node<DataT>>();
for (const edgeSrc of node.srcs) {
nodeParents.add(edgeSrc.src);
const edgeParents = traceNodeParents(edgeSrc.src);
const edgeParents = traceNodeParents(edgeSrc.src, depth + 1);
Array.from(edgeParents).forEach(n => nodeParents.add(n));
}
parents.set(node.id, nodeParents);
@ -91,7 +96,7 @@ export function traceParents<DataT>(nodes: Network<Node<DataT>>) {
}
for (const node of nodes.values()) {
traceNodeParents(node);
traceNodeParents(node, 0);
}
return parents;
}

View File

@ -461,13 +461,13 @@ function testComputeNextGeneration() {
};
resetGlobalIDs({ node_id: 1, innovation_number: 2, species_id: 3 });
const ng = computeNextGeneration(population, fitness, cngc);
const { nextPopulation: np, nextReps: _nextReps } = computeNextGeneration(population, fitness, cngc);
// NOTE: these tests are not very detailed as this is difficult
// to test without mocks
assert(ng.size === population.size);
assert(np.size === population.size);
const sids = new Set(ng.values());
const sids = new Set(np.values());
assert(sids.has(1));
assert(!sids.has(2));
// it not guaranteed that sids.has(3) since the new genomes may have the