create new connection script

This commit is contained in:
Michael Peters 2024-08-26 14:09:48 -07:00
parent 871d531972
commit 889b0dfa77
3 changed files with 56 additions and 19 deletions

View File

@ -228,8 +228,8 @@
* - Effectively pass data from inputs, through hidden nodes, to outputs * - Effectively pass data from inputs, through hidden nodes, to outputs
*/ */
import { edgesToNodes, NodeID, Network, Node, topoSort } from './network'; import { edgesToNodes, NodeID, Network, Node, topoSort, traceParents, RawEdge } from './network';
import { keyMax, randint, randomNegPos } from './util'; import { keyMax, randchoice, randint, randomNegPos, setDifference, setMap } from './util';
interface GeneData { interface GeneData {
innovation: number; innovation: number;
@ -375,15 +375,6 @@ export function chooseSurvivors(population: Population, fitness: Map<Genome, num
// export function mate(a: Genome, b: Genome): Genome {} // export function mate(a: Genome, b: Genome): Genome {}
function getGenomeNodeIDs(genome: Genome): Set<NodeID> {
const nodeIDs = new Set<NodeID>();
for (const gene of genome) {
nodeIDs.add(gene.src_id);
nodeIDs.add(gene.dst_id);
}
return nodeIDs;
}
interface MutateConfig { interface MutateConfig {
mutate_rate: number; // chance to mutate a gene's weight mutate_rate: number; // chance to mutate a gene's weight
assign_rate: number; // chance to assign instead of uniformly perturb assign_rate: number; // chance to assign instead of uniformly perturb
@ -440,13 +431,53 @@ export function mutate(genome: Genome, config: MutateConfig): Genome {
} }
if (Math.random() < new_connection_rate) { if (Math.random() < new_connection_rate) {
// create a new connection // create a new connection between two *previously unconnected* nodes
// between *previously unconnected* nodes // NOTE: there's some performance stuff here that could definitely be improved
// TODO: use traceParents
// TODO: test traceParents
const nodes = edgesToNodes(newGenome); const nodes = edgesToNodes(newGenome);
const sources = Array.from(nodes.values()).filter(n => n.srcs.size === 0);
const sinks = Array.from(nodes.values()).filter(n => n.dsts.size === 0); // find nodes that can be connected without creating a cycle
// a node that is connected to one of its parents creates a cycle
const allNodeIDs = new Set(nodes.keys());
const parents = traceParents(nodes);
const acyclic = new Map<NodeID, Set<NodeID>>();
for (const [nodeID, nodeParents] of parents.entries()) {
const nodeParentIDs = setMap(nodeParents, n => n.id);
const nodeAcyclic = setDifference(allNodeIDs, nodeParentIDs);
acyclic.set(nodeID, nodeAcyclic);
}
// flatten
const acyclicConns: { src_id: NodeID, dst_id: NodeID }[] = [];
for (const [nodeID, nodeAcyclic] of acyclic) {
acyclicConns.push(...setMap(nodeAcyclic, dst => ({ src_id: nodeID, dst_id: dst })));
}
// remove options that are already connected
const options: { src_id: NodeID, dst_id: NodeID }[] = [];
for (const conn of acyclicConns) {
if (newGenome.findIndex(c => c.src_id === conn.src_id && c.dst_id === conn.dst_id) === -1) continue;
options.push(conn);
}
// choose a random connection
if (options.length === 0) {
// TODO: remove this warn once this starts working
// this is mostly a sanity check / useful for metrics
console.warn("could not find a valid new connection!");
} else {
const { src_id, dst_id } = randchoice(options);
const newGene = {
src_id,
dst_id,
data: {
innovation: g_innovation_number++,
weight: assign_mag * randomNegPos(),
enabled: true,
}
}
newGenome.push(newGene);
}
} }
return newGenome; return newGenome;
@ -476,6 +507,7 @@ export class Organism {
/** given per-node input activations, computes downstream activations in-place.*/ /** given per-node input activations, computes downstream activations in-place.*/
think(activations: Map<NodeID, number>) { think(activations: Map<NodeID, number>) {
// TODO: do not follow disabled connections
for (const id of this.order) { for (const id of this.order) {
const node = this.network.get(id)!; const node = this.network.get(id)!;
if (node.srcs.size === 0) { if (node.srcs.size === 0) {

View File

@ -7,6 +7,7 @@ export function randint(low: number, high: number) {
return Math.floor(Math.random() * range) + low; return Math.floor(Math.random() * range) + low;
} }
/** deprecated, use util.ts instead */
export function randchoice<T>(arr: T[]) { export function randchoice<T>(arr: T[]) {
return arr[randint(0, arr.length - 1)]!; return arr[randint(0, arr.length - 1)]!;
} }

View File

@ -9,6 +9,10 @@ export function randint(low: number, high: number) {
return Math.floor(Math.random() * range) + low; return Math.floor(Math.random() * range) + low;
} }
export function randchoice<T>(arr: T[]) {
return arr[randint(0, arr.length - 1)]!;
}
export function keyMax<T>(values: Iterable<T>, keyFunc: (v: T) => number) { export function keyMax<T>(values: Iterable<T>, keyFunc: (v: T) => number) {
let best = null; let best = null;
for (const value of values) { for (const value of values) {