211 lines
5.2 KiB
Python
211 lines
5.2 KiB
Python
from __future__ import annotations
|
|
|
|
import pdb
|
|
import random
|
|
import typing as tp
|
|
from dataclasses import dataclass
|
|
from functools import cache
|
|
|
|
import pdbp
|
|
|
|
# Debugger setup.
# NOTE(review): `DefaultConfig` / `sticky_by_default` look like pdb++/pdbp
# settings, presumably installed onto `pdb` by the `import pdbp` above, so that
# `breakpoint()` (used in get_human_move) opens the non-sticky view — confirm
# against the pdbp docs.
# The hasattr guard makes this a no-op whenever `pdb` has no DefaultConfig.
if hasattr(pdb, "DefaultConfig"):
    pdb.DefaultConfig.sticky_by_default = False  # type:ignore
|
|
|
|
Player = tp.Literal["X", "O"]
|
|
Winner = Player | tp.Literal["draw"]
|
|
Square = Player | tp.Literal[" "]
|
|
Board = tuple[Square, Square, Square, Square, Square, Square, Square, Square, Square]
|
|
Score = int
|
|
|
|
# TODO: "get best move"
|
|
# TODO: alpha-beta pruning (elegantly)
|
|
# TODO: exploit rotational / reflectional board symmetry (for fewer total nodes)
|
|
|
|
CALL_COUNTS = {}
|
|
T = tp.TypeVar("T")
|
|
P = tp.ParamSpec("P")
|
|
|
|
|
|
def count_calls(f: tp.Callable[P, T]) -> tp.Callable[P, T]:
|
|
# TODO: this isn't giving accurate results with the play_human?
|
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
cc = CALL_COUNTS.setdefault(f.__name__, 0)
|
|
CALL_COUNTS[f.__name__] = cc + 1
|
|
return f(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
|
|
class QuitException(Exception):
    """Raised when the human player enters ``q`` to abandon the current game."""
|
|
|
|
|
|
@dataclass(frozen=True)
class GameState:
    """Immutable snapshot of a game: whose turn it is plus the nine squares.

    Being frozen makes instances hashable, which the ``@cache``-decorated
    helpers rely on to memoize by position.
    """

    # Side whose turn it is to move.
    player: Player
    # Cells in row-major order: index 0 is top-left, index 8 is bottom-right.
    board: Board

    def __str__(self) -> str:
        """Draw the board as a 3x3 grid separated by box-drawing characters."""
        cells = self.board

        def row(first: int) -> str:
            # One rendered row: the three cells starting at index `first`.
            return f" {cells[first]} │ {cells[first + 1]} │ {cells[first + 2]}"

        return "\n───┼───┼───\n".join(row(start) for start in (0, 3, 6))
|
|
|
|
|
|
@dataclass(frozen=True)
class Move:
    """Placing the current player's mark on the square at ``position``.

    ``position`` is a zero-based board index. The string form is one-based
    (1-9), matching the keypad layout shown to the human player.
    """

    # Zero-based index into GameState.board.
    position: int

    def __str__(self) -> str:
        """Return the one-based square number, e.g. ``"1"`` for position 0."""
        return f"{self.position + 1}"
|
|
|
|
|
|
@cache
def get_valid_moves(state: GameState) -> tuple[Move, ...]:
    """Return one Move per empty (``" "``) square, in ascending board order.

    The result is variable-length (anywhere from 0 to 9 moves), hence the
    ``tuple[Move, ...]`` annotation; ``tuple[Move]`` would mean exactly one.
    """
    return tuple(
        Move(position=i) for i, square in enumerate(state.board) if square == " "
    )
|
|
|
|
|
|
@cache
def apply_move(state: GameState, move: Move) -> GameState:
    """Return the position after ``state.player`` marks ``move.position``.

    ``state`` itself is not modified; a new GameState is built with the mark
    placed and the turn handed to the other side. No check is made here that
    the target square is empty or that the game is still in progress.
    """
    squares = list(state.board)
    squares[move.position] = state.player
    next_player: Player = "X" if state.player == "O" else "O"
    return GameState(player=next_player, board=tuple(squares))
|
|
|
|
|
|
@cache
def get_winner(state: GameState) -> Winner | None:
    """Return the winning mark, ``"draw"`` for a full board, or None.

    Cells are named::

        a b c
        d e f
        g h i

    Rows are checked first, then columns, then the two diagonals. The first
    line holding three identical non-blank marks decides the result. A board
    with no such line and no empty cell is a draw; anything else is still
    in progress (None).
    """
    # Unpacking (rather than indexing) also insists on exactly nine cells.
    a, b, c, d, e, f, g, h, i = state.board
    lines = (
        (a, b, c), (d, e, f), (g, h, i),  # rows
        (a, d, g), (b, e, h), (c, f, i),  # columns
        (a, e, i), (c, e, g),             # diagonals
    )
    for first, second, third in lines:
        if first == second == third != " ":
            return first

    if " " not in state.board:
        return "draw"

    return None
|
|
|
|
|
|
@cache
def get_next_states(state: GameState) -> tuple[tuple[Move, GameState], ...]:
    """Pair each legal move from ``state`` with the position it leads to.

    Pairs follow the order of ``get_valid_moves``. Only meaningful for an
    unfinished game; the assert enforces that invariant for callers.
    """
    assert get_winner(state) is None, "should not be called if game ended"
    moves = get_valid_moves(state)
    children = (apply_move(state, move) for move in moves)
    return tuple(zip(moves, children))
|
|
|
|
|
|
@cache
@count_calls
def get_score(target: Player, state: GameState) -> Score:
    """Minimax value of ``state`` from ``target``'s point of view.

    A finished game scores 1 if ``target`` won, 0 for a draw and -1 otherwise.
    An unfinished game takes the maximum over its children when it is
    ``target``'s turn and the minimum when it is the opponent's turn, i.e.
    the opponent is assumed to play perfectly. Search depth is not taken into
    account, so a quick win and a slow win score the same.
    """
    outcome = get_winner(state)

    if outcome is None:
        # Game still running: back up the children's values.
        choose = max if state.player == target else min
        child_scores = (
            get_score(target, child) for _, child in get_next_states(state)
        )
        return choose(child_scores)

    if outcome == target:
        return 1
    if outcome == "draw":
        return 0
    # The game ended and the winner is not target: the opponent won.
    return -1
|
|
|
|
|
|
@cache
def get_best_moves(state: GameState) -> tuple[Move, ...]:
    """Return every move that attains the best minimax score for the mover.

    Each successor is scored from ``state.player``'s perspective and all moves
    tied for the maximum are returned, in move order. Because the function is
    memoized, the debug line is printed only the first time a state is seen.
    """
    mover = state.player
    candidates = [
        (move, get_score(mover, successor))
        for move, successor in get_next_states(state)
    ]
    max_score = max(value for _, value in candidates)
    print(f"{max_score=}")
    return tuple(move for move, value in candidates if value == max_score)
|
|
|
|
|
|
def get_ai_move(state: GameState, rng: random.Random | None = None) -> Move:
    """Pick the AI's move: a random choice among the optimal moves.

    Args:
        state: position to move from; the game must not already be over
            (get_next_states asserts this).
        rng: optional random source used to break ties between equally good
            moves. Defaults to the ``random`` module's shared generator, which
            keeps the original behavior; pass a seeded ``random.Random`` to
            make games reproducible.

    Returns:
        One of the moves returned by ``get_best_moves(state)``.
    """
    best_moves = get_best_moves(state)
    print(f"{best_moves=}")
    choose = random.choice if rng is None else rng.choice
    return choose(best_moves)
|
|
|
|
|
|
def get_human_move(state: GameState) -> Move:
    """Prompt on stdin until the human enters a legal move for ``state``.

    Accepted input (surrounding whitespace is ignored):
      * ``1``-``9`` — the square, numbered as in the printed keypad layout;
      * ``q`` — quit (raises QuitException); end of input is treated the same;
      * any entry ending in ``b`` first drops into the debugger, then the rest
        of the entry is processed normally.
    Empty lines are ignored; anything else prints "bad move" and re-prompts.

    Raises:
        QuitException: the user asked to quit or stdin was closed.
    """
    print("123\n456\n789")
    while True:
        try:
            choice = input("choice: ").strip()
        except EOFError:
            # Ctrl-D / closed stdin: end the game cleanly instead of crashing.
            raise QuitException from None

        if choice.endswith("b"):
            breakpoint()
            choice = choice[:-1]

        if choice == "":
            continue
        if choice == "q":
            raise QuitException

        try:
            move = Move(position=int(choice) - 1)
        except ValueError:
            # Non-numeric input used to crash the whole game here.
            print("bad move")
            continue

        if move not in get_valid_moves(state):
            print("bad move")
            continue

        return move
|
|
|
|
|
|
def play_human(human_player: Player) -> None:
    """Play one interactive game from REAL on stdin/stdout.

    The human moves whenever it is ``human_player``'s turn; the AI moves for
    the other side. The board is printed before every move and once more at
    the end, followed by "draw" or "<mark> wins". If the human quits, "quit"
    is printed instead and the function returns.
    """
    state = REAL
    try:
        winner = get_winner(state)
        while not winner:
            print(state)
            pick = get_human_move if state.player == human_player else get_ai_move
            move = pick(state)
            print(f"{state.player} plays at {move}")
            print()
            state = apply_move(state, move)
            winner = get_winner(state)
    except QuitException:
        print("quit")
        return

    print(state)
    print("draw" if winner == "draw" else f"{winner} wins")
|
|
|
|
|
|
# Starting position of a real game: all nine squares empty, "X" to move.
REAL: GameState = GameState(board=tuple(" " * 9), player="X")
|
|
|
|
if __name__ == "__main__":
    # Debugging snippets kept for reference (uncomment to use):
    #   print(f"{get_score('X', REAL)=}, {CALL_COUNTS['get_score']=}")
    #     -> earlier note recorded total_nodes=5478 (presumably the
    #        CALL_COUNTS['get_score'] value from that line — confirm)
    #   best_moves = get_best_moves(REAL)
    play_human(human_player="X")
|