From 5444feee992d4ee054b1c2e7e5715e913e672bf1 Mon Sep 17 00:00:00 2001 From: Michael Peters Date: Sat, 3 Aug 2024 15:34:26 -0700 Subject: [PATCH] minor refactor --- src/components/snake/canvas.ts | 157 +----------------- .../snake/{hashmap.ts => hashset.ts} | 0 src/components/snake/snake-brain.ts | 106 ++++++++++++ src/components/snake/types.ts | 42 +++++ 4 files changed, 154 insertions(+), 151 deletions(-) rename src/components/snake/{hashmap.ts => hashset.ts} (100%) create mode 100644 src/components/snake/snake-brain.ts create mode 100644 src/components/snake/types.ts diff --git a/src/components/snake/canvas.ts b/src/components/snake/canvas.ts index 8588377..6efbdd1 100644 --- a/src/components/snake/canvas.ts +++ b/src/components/snake/canvas.ts @@ -1,6 +1,6 @@ -import { Brain, Layer, sigmoidNegPos } from './brain'; import { Engine, Keys, randint, UI, Vec2, vec2 } from './game-engine'; -import { HashSet } from './hashmap'; +import { SnakeBrain } from './snake-brain'; +import { SGSHashSet, Snake, SnakeGameState, SnakeGameStateWithHistory } from './types'; const BOARD_SIZE = 600; // px const SQUARE_SIZE = 30; // px @@ -8,49 +8,11 @@ const SQUARE_SIZE = 30; // px const CENTER_X = BOARD_SIZE / 2; const CENTER_Y = BOARD_SIZE / 2; -const BOARD_SQUARES = BOARD_SIZE / SQUARE_SIZE; +export const BOARD_SQUARES = BOARD_SIZE / SQUARE_SIZE; -type Snake = Vec2[]; -interface SnakeGameState { - dead: boolean; - snake: Snake; - apple: Vec2; -} +// general functions ----------------------------------------------------------- -interface SnakeGameStateWithHistory extends SnakeGameState { - history: SGSHashSet; -} - -class SGSHashSet extends HashSet { - static hash(state: SnakeGameState) { - const { dead, snake, apple } = state; - - if (dead) return -1; - - const snakeHash = snake.map(square => square.x + square.y).reduce((prev, curr) => prev + curr); - const appleHash = apple.x + apple.y; - const hash = snakeHash + appleHash; - return hash; - } - - static eq(a: SnakeGameState, b: SnakeGameState) { - if (a.snake.length !== b.snake.length) return false; - if (!a.apple.eq(b.apple)) return false; - if (a.dead !== b.dead) return false; - for (let i = 0; i < a.snake.length; ++i) { - if (!a.snake[i]!.eq(b.snake[i]!)) { - return false; - } - } - return true; - } - - constructor() { - super(SGSHashSet.hash, SGSHashSet.eq); - } -} - -function shallowCopySGS(state: SnakeGameState) { +function shallowCopySGS(state: SnakeGameState): SnakeGameState { return { dead: state.dead, snake: [...state.snake], @@ -58,109 +20,6 @@ function shallowCopySGS(state: SnakeGameState) { }; } -function isOutOfBounds(square: Vec2) { - return square.x < 0 || square.x >= BOARD_SQUARES || square.y < 0 || square.y >= BOARD_SQUARES; -} - -function isNextHeadInNextSnake(snake: Snake, nextHead: Vec2) { - for (const square of snake.slice(1)) { - if (nextHead.eq(square)) { - return true; - } - } - return false; -} - -class SnakeBrain { - brain: Brain; - - constructor(brain: Brain) { - this.brain = brain; - } - - think(state: SnakeGameState): Vec2 | 'dead' { - const { snake, apple } = state; - - const head = snake[snake.length - 1]!; - - const moves = [vec2(0, +1), vec2(0, -1), vec2(+1, 0), vec2(-1, 0)]; - const nextHeads = moves.map(m => head.add(m)); - const valid = nextHeads.map(nh => !isOutOfBounds(nh) && !isNextHeadInNextSnake(snake, nh)); - - const firstValidIdx = valid.findIndex(v => v); - if (firstValidIdx === -1) return 'dead'; - - // feature layer: - // - head x & y - // - relative apple x & y to head - // - +/- x, +/- y distance to tail - // future ideas: - // - +/- x, +/- y distance to tail at t+1 - - const appleRel = apple.sub(head); - - const above: number[] = []; - const below: number[] = []; - const left: number[] = []; - const right: number[] = []; - for (let i = 0; i < snake.length - 1; ++i) { - const tail = snake[i]!; - if (tail.x === head.x) { - if (tail.y > head.y) above.push(tail.y - head.y); - if (tail.y < head.y) below.push(head.y - tail.y); - } - if (tail.y === head.y) { - if (tail.x > head.x) right.push(tail.x - head.x); - if (tail.x < head.x) left.push(head.x - tail.x); - } - } - - // 8 inputs, ... hidden nodes, 4 outputs - const input = [ - head.x, - head.y, - appleRel.x, - appleRel.y, - Math.max(...above, BOARD_SQUARES), - Math.max(...below, BOARD_SQUARES), - Math.max(...left, BOARD_SQUARES), - Math.max(...right, BOARD_SQUARES), - ]; - const output = this.brain.think(input); - - const moveIdx = output.reduce( - (prevIdx, curr, idx) => (valid[idx]! && curr > output[prevIdx]! ? idx : prevIdx), - firstValidIdx, - ); - - const move = moves[moveIdx]!; - return move; - } - - static fromRandom({ hiddenLayerNodes }: { hiddenLayerNodes: number }) { - const INPUT_NODES = 8; - const OUTPUT_NODES = 4; - - const hiddenLayer = Layer.makeRandomLayer({ - inputs: INPUT_NODES, - outputs: hiddenLayerNodes, - mag: 1, - activation: sigmoidNegPos, - }); - const outputLayer = Layer.makeRandomLayer({ - inputs: hiddenLayerNodes, - outputs: OUTPUT_NODES, - mag: 1, - activation: sigmoidNegPos, - }); - - const brain = new Brain([hiddenLayer, outputLayer]); - return new SnakeBrain(brain); - } -} - -// general functions ----------------------------------------------------------- - function getRandApplePos() { return vec2(randint(0, BOARD_SQUARES), randint(0, BOARD_SQUARES)); } @@ -202,6 +61,7 @@ export default function runCanvas(canvas: HTMLCanvasElement) { } history.add(shallowCopySGS(state)); + // perform ai const dir = brain.think(state); // NOTE: brain.think handles out-of-bounds/tail intersect checking when it identifies @@ -213,11 +73,6 @@ export default function runCanvas(canvas: HTMLCanvasElement) { const nextHead = getSnakeNextSquare(snake, dir); - // check for snake out of bounds or intersection with tail - if (isOutOfBounds(nextHead) || isNextHeadInNextSnake(snake, nextHead)) { - state.dead = true; - } - // check for snake hitting apple if (nextHead.eq(apple)) { state.apple = getRandApplePos(); diff --git a/src/components/snake/hashmap.ts b/src/components/snake/hashset.ts similarity index 100% rename from src/components/snake/hashmap.ts rename to src/components/snake/hashset.ts diff --git a/src/components/snake/snake-brain.ts b/src/components/snake/snake-brain.ts new file mode 100644 index 0000000..33d4b9c --- /dev/null +++ b/src/components/snake/snake-brain.ts @@ -0,0 +1,106 @@ +import { Brain, Layer, sigmoidNegPos } from './brain'; +import { BOARD_SQUARES } from './canvas'; +import { vec2, Vec2 } from './game-engine'; +import { Snake, SnakeGameState } from './types'; + +function isOutOfBounds(square: Vec2) { + return square.x < 0 || square.x >= BOARD_SQUARES || square.y < 0 || square.y >= BOARD_SQUARES; +} + +function isNextHeadInNextSnake(snake: Snake, nextHead: Vec2) { + for (const square of snake.slice(1)) { + if (nextHead.eq(square)) { + return true; + } + } + return false; +} + +const MOVES = [vec2(0, +1), vec2(0, -1), vec2(+1, 0), vec2(-1, 0)]; + +export class SnakeBrain { + brain: Brain; + + constructor(brain: Brain) { + this.brain = brain; + } + + think(state: SnakeGameState): Vec2 | 'dead' { + const { snake, apple } = state; + + const head = snake[snake.length - 1]!; + + const nextHeads = MOVES.map(m => head.add(m)); + const valid = nextHeads.map(nh => !isOutOfBounds(nh) && !isNextHeadInNextSnake(snake, nh)); + + const firstValidIdx = valid.findIndex(v => v); + if (firstValidIdx === -1) return 'dead'; + + // feature layer: + // - head x & y + // - relative apple x & y to head + // - +/- x, +/- y distance to tail + // future ideas: + // - +/- x, +/- y distance to tail at t+1 + + const appleRel = apple.sub(head); + + const above: number[] = []; + const below: number[] = []; + const left: number[] = []; + const right: number[] = []; + for (let i = 0; i < snake.length - 1; ++i) { + const tail = snake[i]!; + if (tail.x === head.x) { + if (tail.y > head.y) above.push(tail.y - head.y); + if (tail.y < head.y) below.push(head.y - tail.y); + } + if (tail.y === head.y) { + if (tail.x > head.x) right.push(tail.x - head.x); + if (tail.x < head.x) left.push(head.x - tail.x); + } + } + + // 8 inputs, ... hidden nodes, 4 outputs + const input = [ + head.x, + head.y, + appleRel.x, + appleRel.y, + Math.max(...above, BOARD_SQUARES), + Math.max(...below, BOARD_SQUARES), + Math.max(...left, BOARD_SQUARES), + Math.max(...right, BOARD_SQUARES), + ]; + const output = this.brain.think(input); + + const moveIdx = output.reduce( + (prevIdx, curr, idx) => (valid[idx]! && curr > output[prevIdx]! ? idx : prevIdx), + firstValidIdx, + ); + + const move = MOVES[moveIdx]!; + return move; + } + + static fromRandom({ hiddenLayerNodes }: { hiddenLayerNodes: number }) { + const INPUT_NODES = 8; + const OUTPUT_NODES = 4; + + const hiddenLayer = Layer.makeRandomLayer({ + inputs: INPUT_NODES, + outputs: hiddenLayerNodes, + mag: 1, + activation: sigmoidNegPos, + }); + const outputLayer = Layer.makeRandomLayer({ + inputs: hiddenLayerNodes, + outputs: OUTPUT_NODES, + mag: 1, + activation: sigmoidNegPos, + }); + + const brain = new Brain([hiddenLayer, outputLayer]); + return new SnakeBrain(brain); + } +} diff --git a/src/components/snake/types.ts b/src/components/snake/types.ts new file mode 100644 index 0000000..41b5f55 --- /dev/null +++ b/src/components/snake/types.ts @@ -0,0 +1,42 @@ +import { Vec2 } from './game-engine'; +import { HashSet } from './hashset'; + +export type Snake = Vec2[]; +export interface SnakeGameState { + dead: boolean; + snake: Snake; + apple: Vec2; +} + +export interface SnakeGameStateWithHistory extends SnakeGameState { + history: SGSHashSet; +} + +export class SGSHashSet extends HashSet { + static hash(state: SnakeGameState) { + const { dead, snake, apple } = state; + + if (dead) return -1; + + const snakeHash = snake.map(square => square.x + square.y).reduce((prev, curr) => prev + curr); + const appleHash = apple.x + apple.y; + const hash = snakeHash + appleHash; + return hash; + } + + static eq(a: SnakeGameState, b: SnakeGameState) { + if (a.snake.length !== b.snake.length) return false; + if (!a.apple.eq(b.apple)) return false; + if (a.dead !== b.dead) return false; + for (let i = 0; i < a.snake.length; ++i) { + if (!a.snake[i]!.eq(b.snake[i]!)) { + return false; + } + } + return true; + } + + constructor() { + super(SGSHashSet.hash, SGSHashSet.eq); + } +}