minor refactor

This commit is contained in:
Michael Peters 2024-08-03 15:34:26 -07:00
parent 8783c7a1a7
commit 5444feee99
4 changed files with 154 additions and 151 deletions

View File

@ -1,6 +1,6 @@
import { Brain, Layer, sigmoidNegPos } from './brain';
import { Engine, Keys, randint, UI, Vec2, vec2 } from './game-engine';
import { HashSet } from './hashmap';
import { SnakeBrain } from './snake-brain';
import { SGSHashSet, Snake, SnakeGameState, SnakeGameStateWithHistory } from './types';
const BOARD_SIZE = 600; // px
const SQUARE_SIZE = 30; // px
@ -8,49 +8,11 @@ const SQUARE_SIZE = 30; // px
const CENTER_X = BOARD_SIZE / 2;
const CENTER_Y = BOARD_SIZE / 2;
const BOARD_SQUARES = BOARD_SIZE / SQUARE_SIZE;
export const BOARD_SQUARES = BOARD_SIZE / SQUARE_SIZE;
type Snake = Vec2[];
interface SnakeGameState {
dead: boolean;
snake: Snake;
apple: Vec2;
}
// general functions -----------------------------------------------------------
interface SnakeGameStateWithHistory extends SnakeGameState {
history: SGSHashSet;
}
class SGSHashSet extends HashSet<SnakeGameState> {
static hash(state: SnakeGameState) {
const { dead, snake, apple } = state;
if (dead) return -1;
const snakeHash = snake.map(square => square.x + square.y).reduce((prev, curr) => prev + curr);
const appleHash = apple.x + apple.y;
const hash = snakeHash + appleHash;
return hash;
}
static eq(a: SnakeGameState, b: SnakeGameState) {
if (a.snake.length !== b.snake.length) return false;
if (!a.apple.eq(b.apple)) return false;
if (a.dead !== b.dead) return false;
for (let i = 0; i < a.snake.length; ++i) {
if (!a.snake[i]!.eq(b.snake[i]!)) {
return false;
}
}
return true;
}
constructor() {
super(SGSHashSet.hash, SGSHashSet.eq);
}
}
function shallowCopySGS(state: SnakeGameState) {
function shallowCopySGS(state: SnakeGameState): SnakeGameState {
return {
dead: state.dead,
snake: [...state.snake],
@ -58,109 +20,6 @@ function shallowCopySGS(state: SnakeGameState) {
};
}
function isOutOfBounds(square: Vec2) {
return square.x < 0 || square.x >= BOARD_SQUARES || square.y < 0 || square.y >= BOARD_SQUARES;
}
function isNextHeadInNextSnake(snake: Snake, nextHead: Vec2) {
for (const square of snake.slice(1)) {
if (nextHead.eq(square)) {
return true;
}
}
return false;
}
class SnakeBrain {
brain: Brain;
constructor(brain: Brain) {
this.brain = brain;
}
think(state: SnakeGameState): Vec2 | 'dead' {
const { snake, apple } = state;
const head = snake[snake.length - 1]!;
const moves = [vec2(0, +1), vec2(0, -1), vec2(+1, 0), vec2(-1, 0)];
const nextHeads = moves.map(m => head.add(m));
const valid = nextHeads.map(nh => !isOutOfBounds(nh) && !isNextHeadInNextSnake(snake, nh));
const firstValidIdx = valid.findIndex(v => v);
if (firstValidIdx === -1) return 'dead';
// feature layer:
// - head x & y
// - relative apple x & y to head
// - +/- x, +/- y distance to tail
// future ideas:
// - +/- x, +/- y distance to tail at t+1
const appleRel = apple.sub(head);
const above: number[] = [];
const below: number[] = [];
const left: number[] = [];
const right: number[] = [];
for (let i = 0; i < snake.length - 1; ++i) {
const tail = snake[i]!;
if (tail.x === head.x) {
if (tail.y > head.y) above.push(tail.y - head.y);
if (tail.y < head.y) below.push(head.y - tail.y);
}
if (tail.y === head.y) {
if (tail.x > head.x) right.push(tail.x - head.x);
if (tail.x < head.x) left.push(head.x - tail.x);
}
}
// 8 inputs, ... hidden nodes, 4 outputs
const input = [
head.x,
head.y,
appleRel.x,
appleRel.y,
Math.max(...above, BOARD_SQUARES),
Math.max(...below, BOARD_SQUARES),
Math.max(...left, BOARD_SQUARES),
Math.max(...right, BOARD_SQUARES),
];
const output = this.brain.think(input);
const moveIdx = output.reduce(
(prevIdx, curr, idx) => (valid[idx]! && curr > output[prevIdx]! ? idx : prevIdx),
firstValidIdx,
);
const move = moves[moveIdx]!;
return move;
}
static fromRandom({ hiddenLayerNodes }: { hiddenLayerNodes: number }) {
const INPUT_NODES = 8;
const OUTPUT_NODES = 4;
const hiddenLayer = Layer.makeRandomLayer({
inputs: INPUT_NODES,
outputs: hiddenLayerNodes,
mag: 1,
activation: sigmoidNegPos,
});
const outputLayer = Layer.makeRandomLayer({
inputs: hiddenLayerNodes,
outputs: OUTPUT_NODES,
mag: 1,
activation: sigmoidNegPos,
});
const brain = new Brain([hiddenLayer, outputLayer]);
return new SnakeBrain(brain);
}
}
// general functions -----------------------------------------------------------
function getRandApplePos() {
return vec2(randint(0, BOARD_SQUARES), randint(0, BOARD_SQUARES));
}
@ -202,6 +61,7 @@ export default function runCanvas(canvas: HTMLCanvasElement) {
}
history.add(shallowCopySGS(state));
// perform ai
const dir = brain.think(state);
// NOTE: brain.think handles out-of-bounds/tail intersect checking when it identifies
@ -213,11 +73,6 @@ export default function runCanvas(canvas: HTMLCanvasElement) {
const nextHead = getSnakeNextSquare(snake, dir);
// check for snake out of bounds or intersection with tail
if (isOutOfBounds(nextHead) || isNextHeadInNextSnake(snake, nextHead)) {
state.dead = true;
}
// check for snake hitting apple
if (nextHead.eq(apple)) {
state.apple = getRandApplePos();

View File

@ -0,0 +1,106 @@
import { Brain, Layer, sigmoidNegPos } from './brain';
import { BOARD_SQUARES } from './canvas';
import { vec2, Vec2 } from './game-engine';
import { Snake, SnakeGameState } from './types';
function isOutOfBounds(square: Vec2) {
return square.x < 0 || square.x >= BOARD_SQUARES || square.y < 0 || square.y >= BOARD_SQUARES;
}
function isNextHeadInNextSnake(snake: Snake, nextHead: Vec2) {
for (const square of snake.slice(1)) {
if (nextHead.eq(square)) {
return true;
}
}
return false;
}
const MOVES = [vec2(0, +1), vec2(0, -1), vec2(+1, 0), vec2(-1, 0)];
export class SnakeBrain {
brain: Brain;
constructor(brain: Brain) {
this.brain = brain;
}
think(state: SnakeGameState): Vec2 | 'dead' {
const { snake, apple } = state;
const head = snake[snake.length - 1]!;
const nextHeads = MOVES.map(m => head.add(m));
const valid = nextHeads.map(nh => !isOutOfBounds(nh) && !isNextHeadInNextSnake(snake, nh));
const firstValidIdx = valid.findIndex(v => v);
if (firstValidIdx === -1) return 'dead';
// feature layer:
// - head x & y
// - relative apple x & y to head
// - +/- x, +/- y distance to tail
// future ideas:
// - +/- x, +/- y distance to tail at t+1
const appleRel = apple.sub(head);
const above: number[] = [];
const below: number[] = [];
const left: number[] = [];
const right: number[] = [];
for (let i = 0; i < snake.length - 1; ++i) {
const tail = snake[i]!;
if (tail.x === head.x) {
if (tail.y > head.y) above.push(tail.y - head.y);
if (tail.y < head.y) below.push(head.y - tail.y);
}
if (tail.y === head.y) {
if (tail.x > head.x) right.push(tail.x - head.x);
if (tail.x < head.x) left.push(head.x - tail.x);
}
}
// 8 inputs, ... hidden nodes, 4 outputs
const input = [
head.x,
head.y,
appleRel.x,
appleRel.y,
Math.max(...above, BOARD_SQUARES),
Math.max(...below, BOARD_SQUARES),
Math.max(...left, BOARD_SQUARES),
Math.max(...right, BOARD_SQUARES),
];
const output = this.brain.think(input);
const moveIdx = output.reduce(
(prevIdx, curr, idx) => (valid[idx]! && curr > output[prevIdx]! ? idx : prevIdx),
firstValidIdx,
);
const move = MOVES[moveIdx]!;
return move;
}
static fromRandom({ hiddenLayerNodes }: { hiddenLayerNodes: number }) {
const INPUT_NODES = 8;
const OUTPUT_NODES = 4;
const hiddenLayer = Layer.makeRandomLayer({
inputs: INPUT_NODES,
outputs: hiddenLayerNodes,
mag: 1,
activation: sigmoidNegPos,
});
const outputLayer = Layer.makeRandomLayer({
inputs: hiddenLayerNodes,
outputs: OUTPUT_NODES,
mag: 1,
activation: sigmoidNegPos,
});
const brain = new Brain([hiddenLayer, outputLayer]);
return new SnakeBrain(brain);
}
}

View File

@ -0,0 +1,42 @@
import { Vec2 } from './game-engine';
import { HashSet } from './hashset';
export type Snake = Vec2[];
export interface SnakeGameState {
dead: boolean;
snake: Snake;
apple: Vec2;
}
export interface SnakeGameStateWithHistory extends SnakeGameState {
history: SGSHashSet;
}
export class SGSHashSet extends HashSet<SnakeGameState> {
static hash(state: SnakeGameState) {
const { dead, snake, apple } = state;
if (dead) return -1;
const snakeHash = snake.map(square => square.x + square.y).reduce((prev, curr) => prev + curr);
const appleHash = apple.x + apple.y;
const hash = snakeHash + appleHash;
return hash;
}
static eq(a: SnakeGameState, b: SnakeGameState) {
if (a.snake.length !== b.snake.length) return false;
if (!a.apple.eq(b.apple)) return false;
if (a.dead !== b.dead) return false;
for (let i = 0; i < a.snake.length; ++i) {
if (!a.snake[i]!.eq(b.snake[i]!)) {
return false;
}
}
return true;
}
constructor() {
super(SGSHashSet.hash, SGSHashSet.eq);
}
}