move ab pruning for tic tac toe to archive, not nearly as elegant
This commit is contained in:
parent
147b3de441
commit
607ecef0eb
0
archive/tictactoe_with_ab_prune.py
Normal file
0
archive/tictactoe_with_ab_prune.py
Normal file
148
tictactoe.py
148
tictactoe.py
@ -101,7 +101,8 @@ Score = int
|
||||
|
||||
@cache
|
||||
def get_score(target: Player, state: GameState) -> Score:
|
||||
global_manage(state)
|
||||
global total_nodes
|
||||
total_nodes += 1
|
||||
|
||||
winner = get_winner(state)
|
||||
if winner == target:
|
||||
@ -117,148 +118,9 @@ def get_score(target: Player, state: GameState) -> Score:
|
||||
return score
|
||||
|
||||
|
||||
ScoreOrPruned = Score | tp.Literal["pruned"]
|
||||
|
||||
ScoreAgg = tp.Callable[[Score, Score], Score]
|
||||
ScoreOrPruneAgg = tp.Callable[[Score, ScoreOrPruned], Score]
|
||||
|
||||
|
||||
def smax(a: Score, b: Score) -> Score:
|
||||
return max(a, b)
|
||||
|
||||
|
||||
def smin(a: Score, b: Score) -> Score:
|
||||
return min(a, b)
|
||||
|
||||
|
||||
ScoreABPruneCallable = tp.Callable[[Player, GameState, Score | None], ScoreOrPruned]
|
||||
|
||||
|
||||
memo_misses = 0
|
||||
|
||||
|
||||
def ab_prune_cache(func: ScoreABPruneCallable) -> ScoreABPruneCallable:
|
||||
# memoization maps player, gamestate -> resulting score (or pruned), prune cutoff
|
||||
memo: dict[tuple[Player, GameState], tuple[ScoreOrPruned, Score | None]] = {}
|
||||
|
||||
# TODO: this could be further improved by starting at the pruned score and skipping
|
||||
def cached(
|
||||
target: Player, state: GameState, prune_cutoff: Score | None
|
||||
) -> ScoreOrPruned:
|
||||
if (target, state) in memo:
|
||||
memo_score, memo_prune_cutoff = memo[(target, state)]
|
||||
agg_prune = smin if state.player == target else smax
|
||||
if memo_prune_cutoff is None or (
|
||||
prune_cutoff is not None
|
||||
and agg_prune(prune_cutoff, memo_prune_cutoff) == prune_cutoff
|
||||
):
|
||||
# breakpoint() # michael
|
||||
return memo_score
|
||||
if (target, state) in memo:
|
||||
global memo_misses
|
||||
memo_misses += 1
|
||||
# breakpoint()
|
||||
score = func(target, state, prune_cutoff)
|
||||
memo[(target, state)] = (score, prune_cutoff)
|
||||
return score
|
||||
|
||||
return cached
|
||||
|
||||
|
||||
@ab_prune_cache
|
||||
def get_score_ab_prune(
|
||||
target: Player, state: GameState, prune_cutoff: Score | None
|
||||
) -> ScoreOrPruned:
|
||||
global_manage((state, prune_cutoff))
|
||||
|
||||
winner = get_winner(state)
|
||||
if winner == target:
|
||||
return 1
|
||||
if winner == "draw":
|
||||
return 0
|
||||
if winner is not None:
|
||||
# winner must be the opponent
|
||||
return -1
|
||||
|
||||
agg, agg_prune = (smax, smin) if state.player == target else (smin, smax)
|
||||
|
||||
import random
|
||||
|
||||
# _next_states = get_next_states(state)
|
||||
# first_state, *next_states = tuple(random.sample(_next_states, len(_next_states)))
|
||||
first_state, *next_states = get_next_states(state)
|
||||
|
||||
score = get_score_ab_prune(target, first_state, None)
|
||||
assert score != "pruned"
|
||||
if (
|
||||
prune_cutoff is not None
|
||||
and score != prune_cutoff
|
||||
and agg_prune(score, prune_cutoff) == prune_cutoff
|
||||
):
|
||||
return "pruned"
|
||||
for next_state in next_states:
|
||||
next_score = get_score_ab_prune(target, next_state, score)
|
||||
if next_score == "pruned":
|
||||
continue
|
||||
score = agg(score, next_score)
|
||||
if (
|
||||
prune_cutoff is not None
|
||||
and score != prune_cutoff
|
||||
and agg_prune(score, prune_cutoff) == prune_cutoff
|
||||
):
|
||||
return "pruned"
|
||||
return score
|
||||
|
||||
|
||||
# total_nodes = 0
|
||||
# nodes = []
|
||||
|
||||
|
||||
def start_ab(p, b):
|
||||
return get_score_ab_prune(p, b, None)
|
||||
|
||||
|
||||
def start_naive(p, b):
|
||||
return get_score(p, b)
|
||||
|
||||
|
||||
REAL = GameState(player="X", board=(" ",) * 9)
|
||||
total_nodes = 0
|
||||
|
||||
|
||||
def manage_ab(data):
|
||||
global total_nodes
|
||||
state, prune_cutoff = data
|
||||
total_nodes += 1
|
||||
# print(str(state))
|
||||
# # print(repr(state))
|
||||
# print(f"{prune_cutoff=}, {state.player=}")
|
||||
# print()
|
||||
|
||||
|
||||
def manage_naive(data):
|
||||
global total_nodes
|
||||
state = data
|
||||
total_nodes += 1
|
||||
# print(str(state))
|
||||
# print(repr(state))
|
||||
# print()
|
||||
|
||||
|
||||
REAL = GameState(player="X", board=(" ",) * 9)
|
||||
X_WON = GameState(player="X", board=("X", "X", " ", "O", "O", " ", "O", "O", " "))
|
||||
C = GameState(player="X", board=("X", "X", " ", "O", " ", " ", "O", " ", " "))
|
||||
board = REAL
|
||||
# global_manage, get_score_func = manage_naive, start_naive
|
||||
global_manage, get_score_func = manage_ab, start_ab
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# real: total_nodes=5478
|
||||
# x_won: total_nodes=8
|
||||
print(f"{get_score_func('X', board)=}")
|
||||
print(f"{total_nodes=}")
|
||||
print(f"{memo_misses=}")
|
||||
|
||||
# real: total_nodes=9896 (w/o custom cache) / 8503
|
||||
# x_won: total_nodes=6
|
||||
# print(f"{get_score_ab_prune('X', X_WON, None)=}")
|
||||
# total_nodes=5478
|
||||
print(f"{get_score('X', REAL)=}, {total_nodes=}")
|
||||
|
Loading…
Reference in New Issue
Block a user