from __future__ import annotations

import typing as tp
from dataclasses import dataclass
from functools import cache

Player = tp.Literal["X", "O"]
Winner = Player | tp.Literal["draw"]
Square = Player | tp.Literal[" "]
Board = tuple[Square, Square, Square, Square, Square, Square, Square, Square, Square]

# TODO: "get best move"
# TODO: alpha-beta pruning
# TODO: rotational / reflectional board parity (for less total nodes)


@dataclass(frozen=True)
class GameState:
    player: Player
    board: Board

    def __str__(self):
        b = self.board
        return (
            f" {b[0]} │ {b[1]} │ {b[2]}\n"
            "───┼───┼───\n"
            f" {b[3]} │ {b[4]} │ {b[5]}\n"
            "───┼───┼───\n"
            f" {b[6]} │ {b[7]} │ {b[8]}"
        )


@dataclass(frozen=True)
class Move:
    position: int


@cache
def get_valid_moves(state: GameState) -> tuple[Move, ...]:
    return tuple(
        Move(position=i) for i, square in enumerate(state.board) if square == " "
    )


@cache
def apply_move(state: GameState, move: Move) -> GameState:
    new_player = "X" if state.player == "O" else "O"
    new_board = list(state.board)
    new_board[move.position] = state.player
    return GameState(player=new_player, board=tuple(new_board))


@cache
def get_winner(state: GameState) -> Winner | None:
    # board layout:
    #   a b c
    #   d e f
    #   g h i
    a, b, c, d, e, f, g, h, i = state.board
    # horizontal
    if a == b == c and a != " ":
        return a
    if d == e == f and d != " ":
        return d
    if g == h == i and g != " ":
        return g
    # vertical
    if a == d == g and a != " ":
        return a
    if b == e == h and b != " ":
        return b
    if c == f == i and c != " ":
        return c
    # diagonal
    if a == e == i and a != " ":
        return a
    if c == e == g and c != " ":
        return c
    # draw: no line completed and no empty square left
    if not any(square == " " for square in state.board):
        return "draw"
    # no winner yet, game still in progress
    return None


@cache
def get_next_states(state: GameState) -> tuple[GameState, ...]:
    assert get_winner(state) is None, "should not be called if game ended"
    return tuple(apply_move(state, move) for move in get_valid_moves(state))


Score = int


@cache
def get_score(target: Player, state: GameState) -> Score:
    # plain minimax: exact value of `state` from `target`'s perspective
    # (+1 = target wins with best play, 0 = draw, -1 = target loses)
    global_manage(state)
    winner = get_winner(state)
    if winner == target:
        return 1
    if winner == "draw":
        return 0
    if winner is not None:
        # winner must be the opponent
        return -1
    # target picks the best reply, the opponent picks the worst one for target
    agg = max if state.player == target else min
    score = agg(get_score(target, next_state) for next_state in get_next_states(state))
    return score


ScoreOrPruned = Score | tp.Literal["pruned"]
ScoreAgg = tp.Callable[[Score, Score], Score]
ScoreOrPruneAgg = tp.Callable[[Score, ScoreOrPruned], Score]


def smax(a: Score, b: Score) -> Score:
    return max(a, b)


def smin(a: Score, b: Score) -> Score:
    return min(a, b)


ScoreABPruneCallable = tp.Callable[[Player, GameState, Score | None], ScoreOrPruned]

memo_misses = 0


def ab_prune_cache(func: ScoreABPruneCallable) -> ScoreABPruneCallable:
    # memoization maps (player, gamestate) -> (resulting score or "pruned",
    # the prune cutoff that was in effect when the entry was computed)
    memo: dict[tuple[Player, GameState], tuple[ScoreOrPruned, Score | None]] = {}

    # TODO: this could be further improved by starting at the pruned score and skipping
    def cached(
        target: Player, state: GameState, prune_cutoff: Score | None
    ) -> ScoreOrPruned:
        if (target, state) in memo:
            memo_score, memo_prune_cutoff = memo[(target, state)]
            agg_prune = smin if state.player == target else smax
            # reuse the stored result only if it was computed with no cutoff, or
            # with a cutoff no tighter than the current one (so it is still valid)
            if memo_prune_cutoff is None or (
                prune_cutoff is not None
                and agg_prune(prune_cutoff, memo_prune_cutoff) == prune_cutoff
            ):
                return memo_score
            # the entry exists but cannot be reused with this cutoff
            global memo_misses
            memo_misses += 1
        score = func(target, state, prune_cutoff)
        memo[(target, state)] = (score, prune_cutoff)
        return score

    return cached
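
# A minimal sketch for the "rotational / reflectional board parity" TODO above
# (an assumed approach, not wired into the solvers): map a board to a canonical
# representative of its symmetry class so that equivalent positions could share
# cache entries. Using it for real would also require mapping move indices back
# to the original orientation, which this sketch does not attempt.
_ROT90 = (6, 3, 0, 7, 4, 1, 8, 5, 2)  # index permutation for a 90° clockwise turn
_MIRROR = (2, 1, 0, 5, 4, 3, 8, 7, 6)  # index permutation for a left-right flip


def _permute_board(board: Board, perm: tuple[int, ...]) -> Board:
    return tp.cast(Board, tuple(board[i] for i in perm))


def canonical_board(board: Board) -> Board:
    # enumerate all 8 symmetries (4 rotations, each optionally mirrored) and pick
    # the lexicographically smallest tuple as the canonical representative
    variants = []
    current = board
    for _ in range(4):
        variants.append(current)
        variants.append(_permute_board(current, _MIRROR))
        current = _permute_board(current, _ROT90)
    return min(variants)
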

@ab_prune_cache
def get_score_ab_prune(
    target: Player, state: GameState, prune_cutoff: Score | None
) -> ScoreOrPruned:
    # minimax with a one-sided alpha-beta-style cutoff: prune_cutoff is the best
    # score the parent has already secured; once this node can no longer change
    # the parent's choice it returns "pruned" instead of an exact score
    global_manage((state, prune_cutoff))
    winner = get_winner(state)
    if winner == target:
        return 1
    if winner == "draw":
        return 0
    if winner is not None:
        # winner must be the opponent
        return -1
    agg, agg_prune = (smax, smin) if state.player == target else (smin, smax)
    # TODO: randomizing the move order (e.g. random.sample over get_next_states)
    # might improve pruning on average
    first_state, *next_states = get_next_states(state)
    # the first child is searched without a cutoff to establish a baseline score
    score = get_score_ab_prune(target, first_state, None)
    assert score != "pruned"
    if (
        prune_cutoff is not None
        and score != prune_cutoff
        and agg_prune(score, prune_cutoff) == prune_cutoff
    ):
        return "pruned"
    for next_state in next_states:
        next_score = get_score_ab_prune(target, next_state, score)
        if next_score == "pruned":
            # a pruned child cannot improve this node's current best, so skip it
            continue
        score = agg(score, next_score)
        if (
            prune_cutoff is not None
            and score != prune_cutoff
            and agg_prune(score, prune_cutoff) == prune_cutoff
        ):
            return "pruned"
    return score


def start_ab(p: Player, b: GameState) -> ScoreOrPruned:
    return get_score_ab_prune(p, b, None)


def start_naive(p: Player, b: GameState) -> Score:
    return get_score(p, b)


total_nodes = 0


def manage_ab(data):
    global total_nodes
    state, prune_cutoff = data
    total_nodes += 1
    # print(str(state))
    # print(repr(state))
    # print(f"{prune_cutoff=}, {state.player=}")
    # print()


def manage_naive(data):
    global total_nodes
    state = data
    total_nodes += 1
    # print(str(state))
    # print(repr(state))
    # print()


REAL = GameState(player="X", board=(" ",) * 9)
X_WON = GameState(player="X", board=("X", "X", " ", "O", "O", " ", "O", "O", " "))
C = GameState(player="X", board=("X", "X", " ", "O", " ", " ", "O", " ", " "))

board = REAL
# global_manage, get_score_func = manage_naive, start_naive
global_manage, get_score_func = manage_ab, start_ab

if __name__ == "__main__":
    # real: total_nodes=5478
    # x_won: total_nodes=8
    print(f"{get_score_func('X', board)=}")
    print(f"{total_nodes=}")
    print(f"{memo_misses=}")

    # real: total_nodes=9896 (w/o custom cache) / 8503
    # x_won: total_nodes=6
    # print(f"{get_score_ab_prune('X', X_WON, None)=}")
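
# A minimal sketch for the "get best move" TODO at the top of the file (a
# hypothetical helper, not used by __main__): score every legal reply with
# whichever solver is currently selected via get_score_func and return the
# highest-scoring move for the side to move. Ties go to the lowest position.
def get_best_move(state: GameState) -> Move:
    assert get_winner(state) is None, "game is already over"
    best_move: Move | None = None
    best_score: Score | None = None
    for move in get_valid_moves(state):
        score = get_score_func(state.player, apply_move(state, move))
        assert score != "pruned"  # root-level calls pass no cutoff, so never pruned
        if best_score is None or score > best_score:
            best_move, best_score = move, score
    assert best_move is not None
    return best_move


# example usage (after the module-level setup above):
#   print(get_best_move(REAL))  # e.g. Move(position=0) for the empty board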