265 lines
6.6 KiB
Python
265 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
import typing as tp
|
|
from dataclasses import dataclass
|
|
from functools import cache
|
|
|
|
Player = tp.Literal["X", "O"]
|
|
Winner = Player | tp.Literal["draw"]
|
|
Square = Player | tp.Literal[" "]
|
|
Board = tuple[Square, Square, Square, Square, Square, Square, Square, Square, Square]
|
|
|
|
# TODO: "get best move"
|
|
# TODO: alpha-beta pruning
|
|
# TODO: treat rotated / reflected boards as equivalent (for fewer total nodes)
|
|
|
|
|
|
@dataclass(frozen=True)
class GameState:
    """Immutable snapshot of a game: the player to move and the board contents."""

    # mark that will be placed by the next move applied to this state
    player: Player
    # nine squares in row-major order (indices 0-2 are the top row)
    board: Board

    def __str__(self):
        """Render the board as three rows separated by box-drawing rules."""
        cells = self.board
        rendered_rows = [
            f" {left} │ {middle} │ {right}"
            for left, middle, right in (cells[0:3], cells[3:6], cells[6:9])
        ]
        return "\n───┼───┼───\n".join(rendered_rows)
|
|
|
|
|
|
@dataclass(frozen=True)
class Move:
    """A single placement, identified by the index of the square to fill."""

    # index into the board tuple of the square being played
    position: int
|
|
|
|
|
|
@cache
def get_valid_moves(state: GameState) -> tuple[Move, ...]:
    """Return a ``Move`` for every empty square of ``state.board``, in index order.

    The result has one entry per square equal to ``" "``, so its length varies
    from call to call (hence ``tuple[Move, ...]`` rather than a 1-tuple type).
    Results are memoised per state by ``functools.cache``.
    """
    return tuple(
        Move(position=index)
        for index, square in enumerate(state.board)
        if square == " "
    )
|
|
|
|
|
|
@cache
def apply_move(state: GameState, move: Move) -> GameState:
    """Return the state reached when ``state.player`` marks ``move.position``.

    The input state is left untouched; the returned state holds a new board
    tuple and hands the turn to the other player. No check is made that the
    chosen square was empty.
    """
    mover = state.player

    # Work on a mutable copy, then freeze it again for the new state.
    cells = [*state.board]
    cells[move.position] = mover

    if mover == "O":
        next_player = "X"
    else:
        next_player = "O"

    return GameState(player=next_player, board=tuple(cells))
|
|
|
|
|
|
@cache
def get_winner(state: GameState) -> Winner | None:
    """Classify ``state`` as won, drawn, or still in progress.

    Returns the mark occupying the first completed line found (rows are
    checked first, then columns, then the two diagonals), ``"draw"`` when no
    line is complete and no square is empty, and ``None`` otherwise.
    """
    # Board layout:
    #   a b c
    #   d e f
    #   g h i
    a, b, c, d, e, f, g, h, i = state.board

    lines = (
        # horizontal
        (a, b, c),
        (d, e, f),
        (g, h, i),
        # vertical
        (a, d, g),
        (b, e, h),
        (c, f, i),
        # diagonal
        (a, e, i),
        (c, e, g),
    )
    for first, second, third in lines:
        if first != " " and first == second == third:
            return first

    # no completed line: the game goes on while any square is still empty
    if " " in state.board:
        return None
    return "draw"
|
|
|
|
|
|
@cache
def get_next_states(state: GameState) -> tuple[GameState, ...]:
    """Return every state reachable from ``state`` in exactly one move.

    The successors follow the order of ``get_valid_moves``. The function
    asserts that the game has not already ended.
    """
    assert get_winner(state) is None, "should not be called if game ended"
    successors = [apply_move(state, move) for move in get_valid_moves(state)]
    return tuple(successors)
|
|
|
|
|
|
# Minimax value of a position for the target player: 1 = win, 0 = draw, -1 = loss.
Score = int
|
|
|
|
|
|
@cache
def get_score(target: Player, state: GameState) -> Score:
    """Return the plain minimax value of ``state`` from ``target``'s point of view.

    Finished games score 1 (``target`` won), 0 (draw) or -1 (the other player
    won). Otherwise the value is the maximum of the successor scores when
    ``target`` is to move and the minimum when the opponent is to move.
    The module-level ``global_manage`` hook is called once per evaluation that
    is not served from the ``functools.cache`` memo.
    """
    global_manage(state)

    outcome = get_winner(state)
    if outcome is not None:
        if outcome == target:
            return 1
        if outcome == "draw":
            return 0
        # the only remaining outcome is a win for the opponent
        return -1

    child_scores = (get_score(target, child) for child in get_next_states(state))
    if state.player == target:
        return max(child_scores)
    return min(child_scores)
|
|
|
|
|
|
# Result of the pruning search: a real score, or the marker "pruned" for an abandoned branch.
ScoreOrPruned = Score | tp.Literal["pruned"]
|
|
|
|
# Signature of a two-score combiner such as smax / smin.
ScoreAgg = tp.Callable[[Score, Score], Score]
|
|
# Combiner whose second operand may be "pruned"; NOTE(review): not referenced elsewhere in the visible source.
ScoreOrPruneAgg = tp.Callable[[Score, ScoreOrPruned], Score]
|
|
|
|
|
|
def smax(a: Score, b: Score) -> Score:
    """Return the larger of two scores (strictly two-argument ``max``)."""
    larger = max(a, b)
    return larger
|
|
|
|
|
|
def smin(a: Score, b: Score) -> Score:
    """Return the smaller of two scores (strictly two-argument ``min``)."""
    smaller = min(a, b)
    return smaller
|
|
|
|
|
|
# Signature of the pruning search: (target player, state, parent's cutoff or None) -> score or "pruned".
ScoreABPruneCallable = tp.Callable[[Player, GameState, Score | None], ScoreOrPruned]
|
|
|
|
|
|
# Counter bumped by ab_prune_cache when a memo entry exists but cannot be reused for the requested cutoff.
memo_misses = 0
|
|
|
|
|
|
def ab_prune_cache(func: ScoreABPruneCallable) -> ScoreABPruneCallable:
    """Memoise a pruning search keyed on ``(target, state)``.

    Each memo entry stores the result (a score or ``"pruned"``) together with
    the ``prune_cutoff`` it was computed under. A stored entry is reused when

    * it was computed with no cutoff (``None``), or
    * the new call has a cutoff that is at least as strict as the stored one:
      ``<=`` the stored cutoff when ``state.player == target`` (checked with
      ``smin``), ``>=`` it otherwise (checked with ``smax``).

    In every other case ``func`` is called again and the entry is overwritten.
    The module-level ``memo_misses`` counter is incremented only when an entry
    existed but could not be reused; first-time evaluations are not counted.

    The wrapper carries ``func``'s metadata (``__name__``, ``__doc__``,
    ``__wrapped__``) via ``functools.wraps``.
    """
    from functools import wraps  # local import: keeps this block self-contained

    # memoization maps player, gamestate -> resulting score (or pruned), prune cutoff
    memo: dict[tuple[Player, GameState], tuple[ScoreOrPruned, Score | None]] = {}

    # TODO: this could be further improved by starting at the pruned score and skipping
    @wraps(func)
    def cached(
        target: Player, state: GameState, prune_cutoff: Score | None
    ) -> ScoreOrPruned:
        global memo_misses

        key = (target, state)
        entry = memo.get(key)  # entries are always tuples, so None means "absent"
        if entry is not None:
            memo_score, memo_prune_cutoff = entry
            if memo_prune_cutoff is None:
                # computed without any cutoff: valid for every caller
                return memo_score
            if prune_cutoff is not None:
                agg_prune = smin if state.player == target else smax
                if agg_prune(prune_cutoff, memo_prune_cutoff) == prune_cutoff:
                    return memo_score
            # an entry exists but was computed under an incompatible cutoff
            memo_misses += 1

        score = func(target, state, prune_cutoff)
        memo[key] = (score, prune_cutoff)
        return score

    return cached
|
|
|
|
|
|
@ab_prune_cache
def get_score_ab_prune(
    target: Player, state: GameState, prune_cutoff: Score | None
) -> ScoreOrPruned:
    """Score ``state`` for ``target`` by minimax, abandoning branches the caller cannot use.

    ``prune_cutoff`` is the running score of the calling (parent) node, or
    ``None`` when the caller supplies no bound. Finished games return 1
    (``target`` won), 0 (draw) or -1 (opponent won). For unfinished games the
    successors are combined with ``smax`` when ``target`` is to move and
    ``smin`` otherwise. As soon as the running score differs from the cutoff
    and lies on the far side of it (above it at a maximising node, below it at
    a minimising node), ``"pruned"`` is returned instead of a score.

    The module-level ``global_manage`` hook is called with
    ``(state, prune_cutoff)`` on every invocation that reaches this body.
    """
    global_manage((state, prune_cutoff))

    winner = get_winner(state)
    if winner == target:
        return 1
    if winner == "draw":
        return 0
    if winner is not None:
        # winner must be the opponent
        return -1

    agg, agg_prune = (smax, smin) if state.player == target else (smin, smax)

    def beyond_cutoff(value: Score) -> bool:
        # True when `value` is strictly past the parent's cutoff in the
        # direction this node optimises, so the parent can never prefer it.
        return (
            prune_cutoff is not None
            and value != prune_cutoff
            and agg_prune(value, prune_cutoff) == prune_cutoff
        )

    first_state, *next_states = get_next_states(state)

    # The first successor is searched without a cutoff so that `score` always
    # starts from a real value rather than "pruned".
    score = get_score_ab_prune(target, first_state, None)
    assert score != "pruned"
    if beyond_cutoff(score):
        return "pruned"

    for next_state in next_states:
        # The running score becomes the cutoff handed down to each successor.
        next_score = get_score_ab_prune(target, next_state, score)
        if next_score == "pruned":
            continue
        score = agg(score, next_score)
        if beyond_cutoff(score):
            return "pruned"
    return score
|
|
|
|
|
|
# total_nodes = 0
|
|
# nodes = []
|
|
|
|
|
|
def start_ab(p, b):
    """Run the pruning search for player ``p`` on state ``b`` with no initial cutoff."""
    no_cutoff = None
    return get_score_ab_prune(p, b, no_cutoff)
|
|
|
|
|
|
def start_naive(p, b):
    """Run the plain (unpruned) minimax search for player ``p`` on state ``b``."""
    result = get_score(p, b)
    return result
|
|
|
|
|
|
# Number of nodes visited so far; incremented by manage_ab / manage_naive.
total_nodes = 0
|
|
|
|
|
|
def manage_ab(data):
    """Node hook for the pruning search: count one visited node.

    ``data`` must unpack into ``(state, prune_cutoff)``. Neither value is used
    beyond the unpacking; they are available here for ad-hoc debug printing.
    """
    global total_nodes
    _state, _prune_cutoff = data  # raises if data is not a two-element iterable
    total_nodes = total_nodes + 1
|
|
|
|
|
|
def manage_naive(data):
    """Node hook for the plain search: count one visited node.

    ``data`` is the state being evaluated; it is not inspected here and is
    only available for ad-hoc debug printing.
    """
    global total_nodes
    total_nodes = total_nodes + 1
|
|
|
|
|
|
# Empty board with X to move.
REAL = GameState(
    player="X",
    board=(
        " ", " ", " ",
        " ", " ", " ",
        " ", " ", " ",
    ),
)

# X to move with two X marks in the top row; no line is complete yet.
X_WON = GameState(
    player="X",
    board=(
        "X", "X", " ",
        "O", "O", " ",
        "O", "O", " ",
    ),
)

# Smaller fixture: X to move, two X marks in the top row, O holding the left column's lower two squares.
C = GameState(
    player="X",
    board=(
        "X", "X", " ",
        "O", " ", " ",
        "O", " ", " ",
    ),
)

# State evaluated when the module is run as a script.
board = REAL

# Node hook and search entry point in use. For the plain search, assign
# manage_naive and start_naive instead.
global_manage = manage_ab
get_score_func = start_ab
|
|
|
|
|
|
if __name__ == "__main__":
    # Previously recorded runs:
    #   real: total_nodes=5478
    #   x_won: total_nodes=8
    print(f"{get_score_func('X', board)=}")
    # The counters are formatted only after the search above has finished.
    print(f"{total_nodes=}", f"{memo_misses=}", sep="\n")

    # Previously recorded runs:
    #   real: total_nodes=9896 (w/o custom cache) / 8503
    #   x_won: total_nodes=6
    # print(f"{get_score_ab_prune('X', X_WON, None)=}")
|