diff --git a/.gitignore b/.gitignore index f7275bb..93526df 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ venv/ +__pycache__/ diff --git a/play.py b/play.py new file mode 100644 index 0000000..3d94acd --- /dev/null +++ b/play.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import random +import typing as tp +from abc import ABC, abstractmethod +from operator import itemgetter + + +class QuitException(Exception): + pass + + +class GameState(tp.Protocol): + player: tp.Any + + +class Move(tp.Protocol): + pass + + +class Player(tp.Protocol): + pass + + +Score = int + + +class GameModule(tp.Protocol): + @staticmethod + def get_valid_moves(state: GameState) -> tp.Iterable[Move]: + ... + + @staticmethod + def apply_move(state: GameState, move: Move) -> GameState: + ... + + @staticmethod + def get_winner(state: GameState) -> Player | tp.Literal["draw"] | None: + ... + + @staticmethod + def get_next_states( + state: GameState, + ) -> tuple[tuple[Move, GameState], ...]: + ... + + @staticmethod + def get_score(target: Player, state: GameState) -> Player: + ... + + @staticmethod + def get_scored_moves( + target: Player, state: GameState + ) -> tuple[tuple[Score, Move], ...]: + ... + + @staticmethod + def get_human_move(state: GameState) -> Move: + ... + + +def _get_ai_move(mod: GameModule, state: GameState) -> Move: + scored_moves = list(mod.get_scored_moves(state.player, state)) + random.shuffle(scored_moves) + _, best_move = max(scored_moves, key=itemgetter(0)) + return best_move + + +def play_human(mod: GameModule, human_player: Player, state: GameState) -> None: + try: + while not (winner := mod.get_winner(state)): + print(state) + if state.player == human_player: + move = mod.get_human_move(state) + else: + move = _get_ai_move(mod, state) + print(f"{state.player} plays {move}") + print() + state = mod.apply_move(state, move) + print(state) + if winner == "draw": + print("draw") + else: + print(f"{winner} wins") + except QuitException: + print("quit") + + +if __name__ == "__main__": + import tictactoe as ttt + + play_human(ttt, "X", ttt.REAL) diff --git a/tictactoe.py b/tictactoe.py index dbe1f42..f565e76 100644 --- a/tictactoe.py +++ b/tictactoe.py @@ -29,7 +29,7 @@ class QuitException(Exception): @dataclass(frozen=True) -class GameState: +class TTTGameState: player: Player board: Board @@ -53,25 +53,25 @@ class Move: # @cache -def get_valid_moves(state: GameState) -> tp.Iterable[Move]: +def get_valid_moves(state: TTTGameState) -> tp.Iterable[Move]: return tuple( Move(position=i) for i, square in enumerate(state.board) if square == " " ) # @cache -def apply_move(state: GameState, move: Move) -> GameState: +def apply_move(state: TTTGameState, move: Move) -> TTTGameState: new_player = "X" if state.player == "O" else "O" new_board = list(state.board) new_board[move.position] = state.player new_board = tuple(new_board) - return GameState(player=new_player, board=new_board) + return TTTGameState(player=new_player, board=new_board) # @cache -def get_winner(state: GameState) -> Winner | None: +def get_winner(state: TTTGameState) -> Winner | None: # abc # def # ghi @@ -108,14 +108,14 @@ def get_winner(state: GameState) -> Winner | None: # @cache -def get_next_states(state: GameState) -> tuple[tuple[Move, GameState], ...]: +def get_next_states(state: TTTGameState) -> tuple[tuple[Move, TTTGameState], ...]: assert get_winner(state) is None, "should not be called if game ended" return tuple((move, apply_move(state, move)) for move in get_valid_moves(state)) # @cache @util.count_calls -def get_score(target: Player, state: GameState) -> Score: +def get_score(target: Player, state: TTTGameState) -> Score: winner = get_winner(state) if winner == target: return 1 @@ -134,14 +134,14 @@ def get_score(target: Player, state: GameState) -> Score: # @cache def get_scored_moves( - target: Player, state: GameState + target: Player, state: TTTGameState ) -> tuple[tuple[Score, Move], ...]: return tuple( (get_score(target, state), move) for move, state in get_next_states(state) ) -def get_human_move(state: GameState) -> Move: +def get_human_move(state: TTTGameState) -> Move: print("123\n456\n789") while True: choice = input("choice: ") @@ -159,7 +159,7 @@ def get_human_move(state: GameState) -> Move: return move -REAL = GameState(player="X", board=(" ",) * 9) +REAL = TTTGameState(player="X", board=(" ",) * 9) # if __name__ == "__main__": # # total_nodes=5478 diff --git a/util.py b/util.py new file mode 100644 index 0000000..fa06422 --- /dev/null +++ b/util.py @@ -0,0 +1,15 @@ +import typing as tp + +CALL_COUNTS = {} +T = tp.TypeVar("T") +P = tp.ParamSpec("P") + + +def count_calls(f: tp.Callable[P, T]) -> tp.Callable[P, T]: + # TODO: this isn't giving accurate results with the play_human? + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + cc = CALL_COUNTS.setdefault(f.__name__, 0) + CALL_COUNTS[f.__name__] = cc + 1 + return f(*args, **kwargs) + + return wrapper