create new connection script

This commit is contained in:
Michael Peters 2024-08-26 14:09:48 -07:00
parent 871d531972
commit 889b0dfa77
3 changed files with 56 additions and 19 deletions

View File

@ -119,7 +119,7 @@
* - use random values connection weights
* - all nodes should be marked "enabled"
* - all organisms in the initial population will be part of species #1
* - this implies that randomized connection weights should be chosen such that c3*W < δ_t for all organisms
* - this implies that randomized connection weights should be chosen such that c3*W < δ_t for all organisms
* 3. Training Loop
* a) Compute fitness f_i for all organisms[i]
* b) Find adjusted f_adj_i based on each organism's species
@ -129,7 +129,7 @@
* - All organisms whose fitness passes the survival threshold (based on f_adj_i) are considered survivors
* - Add all survivors to the next generation
* - Fertility: fertility[i] = # of offspring from organism i as the mom
* - Organisms from species that have not improved in fitness for stag_lim generations are given fertility[i] = 0 (barred from mating)
* - Organisms from species that have not improved in fitness for stag_lim generations are given fertility[i] = 0 (barred from mating)
* - Count the number of fertile survivors
* - Compute general, remainder = divmod((pop - # survivors), # fertile survivors)
* - Select remainder organisms to be given one bonus fertility (at random or by fitness)
@ -228,8 +228,8 @@
* - Effectively pass data from inputs, through hidden nodes, to outputs
*/
import { edgesToNodes, NodeID, Network, Node, topoSort } from './network';
import { keyMax, randint, randomNegPos } from './util';
import { edgesToNodes, NodeID, Network, Node, topoSort, traceParents, RawEdge } from './network';
import { keyMax, randchoice, randint, randomNegPos, setDifference, setMap } from './util';
interface GeneData {
innovation: number;
@ -375,15 +375,6 @@ export function chooseSurvivors(population: Population, fitness: Map<Genome, num
// export function mate(a: Genome, b: Genome): Genome {}
function getGenomeNodeIDs(genome: Genome): Set<NodeID> {
const nodeIDs = new Set<NodeID>();
for (const gene of genome) {
nodeIDs.add(gene.src_id);
nodeIDs.add(gene.dst_id);
}
return nodeIDs;
}
interface MutateConfig {
mutate_rate: number; // chance to mutate a gene's weight
assign_rate: number; // chance to assign instead of uniformly perturb
@ -440,13 +431,53 @@ export function mutate(genome: Genome, config: MutateConfig): Genome {
}
if (Math.random() < new_connection_rate) {
// create a new connection
// between *previously unconnected* nodes
// TODO: use traceParents
// TODO: test traceParents
// create a new connection between two *previously unconnected* nodes
// NOTE: there's some performance stuff here that could definitely be improved
const nodes = edgesToNodes(newGenome);
const sources = Array.from(nodes.values()).filter(n => n.srcs.size === 0);
const sinks = Array.from(nodes.values()).filter(n => n.dsts.size === 0);
// find nodes that can be connected without creating a cycle
// a node that is connected to one of its parents creates a cycle
const allNodeIDs = new Set(nodes.keys());
const parents = traceParents(nodes);
const acyclic = new Map<NodeID, Set<NodeID>>();
for (const [nodeID, nodeParents] of parents.entries()) {
const nodeParentIDs = setMap(nodeParents, n => n.id);
const nodeAcyclic = setDifference(allNodeIDs, nodeParentIDs);
acyclic.set(nodeID, nodeAcyclic);
}
// flatten
const acyclicConns: { src_id: NodeID, dst_id: NodeID }[] = [];
for (const [nodeID, nodeAcyclic] of acyclic) {
acyclicConns.push(...setMap(nodeAcyclic, dst => ({ src_id: nodeID, dst_id: dst })));
}
// remove options that are already connected
const options: { src_id: NodeID, dst_id: NodeID }[] = [];
for (const conn of acyclicConns) {
if (newGenome.findIndex(c => c.src_id === conn.src_id && c.dst_id === conn.dst_id) === -1) continue;
options.push(conn);
}
// choose a random connection
if (options.length === 0) {
// TODO: remove this warn once this starts working
// this is mostly a sanity check / useful for metrics
console.warn("could not find a valid new connection!");
} else {
const { src_id, dst_id } = randchoice(options);
const newGene = {
src_id,
dst_id,
data: {
innovation: g_innovation_number++,
weight: assign_mag * randomNegPos(),
enabled: true,
}
}
newGenome.push(newGene);
}
}
return newGenome;
@ -476,6 +507,7 @@ export class Organism {
/** given per-node input activations, computes downstream activations in-place.*/
think(activations: Map<NodeID, number>) {
// TODO: do not follow disabled connections
for (const id of this.order) {
const node = this.network.get(id)!;
if (node.srcs.size === 0) {

View File

@ -7,6 +7,7 @@ export function randint(low: number, high: number) {
return Math.floor(Math.random() * range) + low;
}
/** deprecated, use util.ts instead */
export function randchoice<T>(arr: T[]) {
return arr[randint(0, arr.length - 1)]!;
}

View File

@ -9,6 +9,10 @@ export function randint(low: number, high: number) {
return Math.floor(Math.random() * range) + low;
}
export function randchoice<T>(arr: T[]) {
return arr[randint(0, arr.length - 1)]!;
}
export function keyMax<T>(values: Iterable<T>, keyFunc: (v: T) => number) {
let best = null;
for (const value of values) {