"""Exhaustive minimax search over the tic-tac-toe game tree."""

from __future__ import annotations

import typing as tp
from dataclasses import dataclass
from functools import cache, cached_property

Player = tp.Literal["X", "O"]
Draw = tp.Literal["-"]
Square = tp.Literal["X", "O", " "]
# Squares are stored row-major: indices 0-2 are the top row, 6-8 the bottom.
Board = tuple[Square, Square, Square, Square, Square, Square, Square, Square, Square]


@dataclass(frozen=True)
class GameState:
    """An immutable position: the player to move next plus the 9 squares."""

    player: Player  # whose turn it is in this position
    board: Board

    def __str__(self):
        """Render the board as a 3x3 grid using box-drawing characters."""
        b = self.board
        return (
            f" {b[0]} │ {b[1]} │ {b[2]}\n"
            "───┼───┼───\n"
            f" {b[3]} │ {b[4]} │ {b[5]}\n"
            "───┼───┼───\n"
            f" {b[6]} │ {b[7]} │ {b[8]}"
        )


@dataclass(frozen=True)
class Move:
    """``player`` places their mark on square ``position`` (0-8, row-major)."""

    player: Player
    position: int


def get_valid_moves(state: GameState) -> tuple[Move, ...]:
    """Return one move for ``state.player`` per empty square, in board order."""
    return tuple(
        Move(state.player, position=i)
        for i, square in enumerate(state.board)
        if square == " "
    )


@cache
def apply_move(state: GameState, move: Move) -> GameState:
    """Return the position after ``move``; the turn passes to the other player.

    Memoized: ``GameState`` and ``Move`` are frozen (hence hashable) dataclasses,
    so identical ``(state, move)`` pairs share one result object.

    Raises:
        IndexError: ``move.position`` is not a square on the board.
        ValueError: the target square is already occupied.
    """
    # Reject negative positions explicitly; plain indexing would silently wrap
    # them around to the end of the board.
    if not 0 <= move.position < len(state.board):
        raise IndexError(f"move position {move.position} is off the board")
    if state.board[move.position] != " ":
        raise ValueError(f"square {move.position} is already occupied")

    new_player = "X" if move.player == "O" else "O"
    new_board = list(state.board)
    new_board[move.position] = move.player
    return GameState(player=new_player, board=tuple(new_board))


# Every three-in-a-row on a row-major 3x3 board:
#   0 1 2
#   3 4 5
#   6 7 8
# Order (rows, then columns, then diagonals) fixes which line is reported first.
_WIN_LINES: tuple[tuple[int, int, int], ...] = (
    # horizontal
    (0, 1, 2),
    (3, 4, 5),
    (6, 7, 8),
    # vertical
    (0, 3, 6),
    (1, 4, 7),
    (2, 5, 8),
    # diagonal
    (0, 4, 8),
    (2, 4, 6),
)


def get_winner(state: GameState) -> Player | Draw | None:
    """Return the winner of ``state``.

    Returns:
        ``"X"`` or ``"O"`` if that player has three in a line, ``"-"`` if the
        board is full with no line, or ``None`` if the game is still in play.
    """
    board = state.board
    for i, j, k in _WIN_LINES:
        if board[i] != " " and board[i] == board[j] == board[k]:
            return board[i]
    # A full board with no completed line is a draw.
    if " " not in board:
        return "-"
    return None


# monte-carlo style tree search

START = GameState(player="X", board=(" ",) * 9)

total_nodes = 0  # number of Node instances constructed so far


@dataclass(frozen=True)
class Node:
    """A node in the game tree, expanded lazily via ``next``.

    Nodes are not de-duplicated: every distinct sequence of moves gets its own
    node, even when two sequences reach the same board.
    """

    state: GameState = START
    depth: int = 0  # moves made since the root Node (children get depth + 1)

    def __post_init__(self):
        # Count every Node ever constructed (reported by the __main__ driver).
        global total_nodes
        total_nodes += 1

    @cached_property
    def winner(self) -> Player | Draw | None:
        """Outcome of this node's position (see ``get_winner``)."""
        return get_winner(self.state)

    # TODO: also store move so that we can use it as a "play"
    # for human-interaction
    @cached_property
    def next(self) -> tuple[Node, ...]:
        """Child nodes, one per legal move; empty once the game is decided.

        Cached so the subtree is built at most once per Node.
        """
        if self.winner is not None:
            return tuple()
        return tuple(
            Node(state=apply_move(self.state, move), depth=self.depth + 1)
            for move in get_valid_moves(self.state)
        )

    # TODO: max depth
    def score(self, target: Player) -> float:
        """Minimax value of this position from ``target``'s point of view.

        Returns ``inf`` if ``target`` wins with best play by both sides,
        ``-inf`` if it loses, and ``0.0`` for a draw. Wins are not discounted
        by depth, so a fast win and a slow win score the same.
        """
        if self.winner == target:
            return float("inf")
        if self.winner == "-":
            return 0.0
        if self.winner is not None:
            return float("-inf")

        # A non-terminal position always has at least one empty square, so
        # `next` is non-empty and max()/min() below never see an empty input.
        # Alternate min/max by turn.
        next_scores = (node.score(target) for node in self.next)
        if self.state.player == target:
            # target plays at this node
            return max(next_scores)
        # opponent plays at this node
        return min(next_scores)


def count_nodes(root: Node) -> int:
    """Count ``root`` plus all of its descendants, expanding the whole subtree."""
    return 1 + sum(count_nodes(node) for node in root.next)


if __name__ == "__main__":
    root = Node()
    print(f"{total_nodes=}")
    print(f"{root=}, {root.score('X')=}")
    print(f"{count_nodes(root)=}")
    print(f"{total_nodes=}")

    # NOTE: total_nodes (549946) is the size of the full game tree. Every
    # distinct sequence of moves, including every prefix, is its own Node. The
    # post below instead counts unique *boards*:
    # https://stackoverflow.com/questions/7466429/generate-a-list-of-all-unique-tic-tac-toe-boards
    # TODO: de-duplicate nodes by state (e.g. share one Node per GameState) to
    # shrink the search to the unique reachable positions.