From 607ecef0eb4298a758ba8296eefabdd70b9cdd7c Mon Sep 17 00:00:00 2001 From: Michael Peters Date: Thu, 1 Jun 2023 23:01:02 -0700 Subject: [PATCH] move ab pruning for tic tac toe to archive, not nearly as elegant --- archive/tictactoe_with_ab_prune.py | 0 tictactoe.py | 148 +---------------------------- 2 files changed, 5 insertions(+), 143 deletions(-) create mode 100644 archive/tictactoe_with_ab_prune.py diff --git a/archive/tictactoe_with_ab_prune.py b/archive/tictactoe_with_ab_prune.py new file mode 100644 index 0000000..e69de29 diff --git a/tictactoe.py b/tictactoe.py index 7be44bf..bcd5e9e 100644 --- a/tictactoe.py +++ b/tictactoe.py @@ -101,7 +101,8 @@ Score = int @cache def get_score(target: Player, state: GameState) -> Score: - global_manage(state) + global total_nodes + total_nodes += 1 winner = get_winner(state) if winner == target: @@ -117,148 +118,9 @@ def get_score(target: Player, state: GameState) -> Score: return score -ScoreOrPruned = Score | tp.Literal["pruned"] - -ScoreAgg = tp.Callable[[Score, Score], Score] -ScoreOrPruneAgg = tp.Callable[[Score, ScoreOrPruned], Score] - - -def smax(a: Score, b: Score) -> Score: - return max(a, b) - - -def smin(a: Score, b: Score) -> Score: - return min(a, b) - - -ScoreABPruneCallable = tp.Callable[[Player, GameState, Score | None], ScoreOrPruned] - - -memo_misses = 0 - - -def ab_prune_cache(func: ScoreABPruneCallable) -> ScoreABPruneCallable: - # memoization maps player, gamestate -> resulting score (or pruned), prune cutoff - memo: dict[tuple[Player, GameState], tuple[ScoreOrPruned, Score | None]] = {} - - # TODO: this could be further improved by starting at the pruned score and skipping - def cached( - target: Player, state: GameState, prune_cutoff: Score | None - ) -> ScoreOrPruned: - if (target, state) in memo: - memo_score, memo_prune_cutoff = memo[(target, state)] - agg_prune = smin if state.player == target else smax - if memo_prune_cutoff is None or ( - prune_cutoff is not None - and agg_prune(prune_cutoff, memo_prune_cutoff) == prune_cutoff - ): - # breakpoint() # michael - return memo_score - if (target, state) in memo: - global memo_misses - memo_misses += 1 - # breakpoint() - score = func(target, state, prune_cutoff) - memo[(target, state)] = (score, prune_cutoff) - return score - - return cached - - -@ab_prune_cache -def get_score_ab_prune( - target: Player, state: GameState, prune_cutoff: Score | None -) -> ScoreOrPruned: - global_manage((state, prune_cutoff)) - - winner = get_winner(state) - if winner == target: - return 1 - if winner == "draw": - return 0 - if winner is not None: - # winner must be the opponent - return -1 - - agg, agg_prune = (smax, smin) if state.player == target else (smin, smax) - - import random - - # _next_states = get_next_states(state) - # first_state, *next_states = tuple(random.sample(_next_states, len(_next_states))) - first_state, *next_states = get_next_states(state) - - score = get_score_ab_prune(target, first_state, None) - assert score != "pruned" - if ( - prune_cutoff is not None - and score != prune_cutoff - and agg_prune(score, prune_cutoff) == prune_cutoff - ): - return "pruned" - for next_state in next_states: - next_score = get_score_ab_prune(target, next_state, score) - if next_score == "pruned": - continue - score = agg(score, next_score) - if ( - prune_cutoff is not None - and score != prune_cutoff - and agg_prune(score, prune_cutoff) == prune_cutoff - ): - return "pruned" - return score - - -# total_nodes = 0 -# nodes = [] - - -def start_ab(p, b): - return get_score_ab_prune(p, b, None) - - -def start_naive(p, b): - return get_score(p, b) - - +REAL = GameState(player="X", board=(" ",) * 9) total_nodes = 0 - -def manage_ab(data): - global total_nodes - state, prune_cutoff = data - total_nodes += 1 - # print(str(state)) - # # print(repr(state)) - # print(f"{prune_cutoff=}, {state.player=}") - # print() - - -def manage_naive(data): - global total_nodes - state = data - total_nodes += 1 - # print(str(state)) - # print(repr(state)) - # print() - - -REAL = GameState(player="X", board=(" ",) * 9) -X_WON = GameState(player="X", board=("X", "X", " ", "O", "O", " ", "O", "O", " ")) -C = GameState(player="X", board=("X", "X", " ", "O", " ", " ", "O", " ", " ")) -board = REAL -# global_manage, get_score_func = manage_naive, start_naive -global_manage, get_score_func = manage_ab, start_ab - - if __name__ == "__main__": - # real: total_nodes=5478 - # x_won: total_nodes=8 - print(f"{get_score_func('X', board)=}") - print(f"{total_nodes=}") - print(f"{memo_misses=}") - - # real: total_nodes=9896 (w/o custom cache) / 8503 - # x_won: total_nodes=6 - # print(f"{get_score_ab_prune('X', X_WON, None)=}") + # total_nodes=5478 + print(f"{get_score('X', REAL)=}, {total_nodes=}")