2023-06-01 07:58:45 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import typing as tp
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from functools import cache, cached_property
|
|
|
|
|
|
|
|
# A player's mark.
Player = tp.Literal["X", "O"]
# Outcome marker used when the board fills up with no completed line.
Draw = tp.Literal["-"]
# Contents of one cell: a player's mark or a blank.
Square = tp.Literal["X", "O", " "]
# The nine cells of the 3x3 grid, row-major (index 0 = top-left).
Board = tuple[(Square,) * 9]
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class GameState:
    """Immutable snapshot of a game: whose turn it is and what is on the board.

    Being frozen (and therefore hashable), instances can serve as cache keys,
    e.g. for the memoised ``apply_move``.
    """

    # The player whose turn it is to move.
    player: Player
    # Nine squares in row-major order (top-left to bottom-right).
    board: Board

    def __str__(self):
        """Render the board as a 3x3 grid with box-drawing separators."""
        cells = self.board
        rows = (
            f" {cells[start]} │ {cells[start + 1]} │ {cells[start + 2]}"
            for start in (0, 3, 6)
        )
        return "\n───┼───┼───\n".join(rows)
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Move:
    """A single placement: *player* puts their mark on square *position*.

    Frozen so that moves are hashable and can be passed to cached functions.
    """

    # Whose mark is being placed.
    player: Player
    # Board index, row-major (0 = top-left, 8 = bottom-right).
    position: int
|
|
|
|
|
|
|
|
|
|
|
|
def get_valid_moves(state: GameState) -> tuple[Move, ...]:
    """Return every legal move for the player to act in *state*.

    One ``Move`` is produced per empty square (``" "``), in ascending board
    position order, each attributed to ``state.player``.  This does not check
    whether the game is already decided; callers must do that themselves
    (``Node.next`` checks ``winner`` first).
    """
    return tuple(
        Move(state.player, position=i)
        for i, square in enumerate(state.board)
        if square == " "
    )
|
|
|
|
|
|
|
|
|
|
|
|
@cache
def apply_move(state: GameState, move: Move) -> GameState:
    """Return the state that results from playing *move* on *state*.

    The mover's mark is written at ``move.position`` and the turn passes to
    the opponent of ``move.player``.  *state* is never mutated (boards are
    tuples), which is what makes memoising every (state, move) pair with
    ``@cache`` safe.

    Raises:
        ValueError: if ``move.position`` is outside the board (a negative
            index would otherwise silently wrap to the far end) or the
            target square is already occupied (which would otherwise be
            silently overwritten).
    """
    board = state.board
    position = move.position
    if not 0 <= position < len(board):
        raise ValueError(f"position {position} is off the board")
    if board[position] != " ":
        raise ValueError(f"square {position} is already occupied")

    # The turn always passes to whoever did not just move.
    new_player = "X" if move.player == "O" else "O"
    # Rebuild the immutable board with exactly one square changed.
    new_board = board[:position] + (move.player,) + board[position + 1 :]
    return GameState(player=new_player, board=new_board)
|
|
|
|
|
|
|
|
|
|
|
|
def get_winner(state: GameState) -> Player | Draw | None:
    """Report the outcome of *state*'s board.

    The nine squares are read row-major as::

        a b c
        d e f
        g h i

    Returns the mark (``"X"``/``"O"``) on the first completed line found,
    checking rows, then columns, then diagonals; ``"-"`` if no line is
    complete and no square is empty; otherwise ``None`` (game in progress).
    """
    a, b, c, d, e, f, g, h, i = state.board
    winning_lines = (
        (a, b, c), (d, e, f), (g, h, i),  # rows
        (a, d, g), (b, e, h), (c, f, i),  # columns
        (a, e, i), (c, e, g),             # diagonals
    )
    for first, second, third in winning_lines:
        if first != " " and first == second == third:
            return first
    return "-" if " " not in state.board else None
|
|
|
|
|
|
|
|
|
|
|
|
# exhaustive minimax game-tree search (every node is expanded; no random sampling)
|
|
|
|
|
|
|
|
# Opening position: an empty board with X to move first.
START = GameState(player="X", board=tuple(" " for _ in range(9)))

# Tally of every Node constructed so far; Node.__post_init__ increments it.
total_nodes = 0
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Node:
    """A position in the lazily expanded tic-tac-toe game tree.

    Children are generated on first access to ``next`` and cached on the
    instance, so each node's subtree is built at most once.  Nodes are not
    shared between different move orders that reach the same board: every
    child is a fresh ``Node``.  This makes the structure a tree rather than a
    DAG, which is why a full expansion creates far more nodes than there are
    distinct boards.
    """

    state: GameState = START
    # Number of moves played from the root (depth 0) to reach this node.
    depth: int = 0

    def __post_init__(self):
        # Instrumentation: count every Node ever constructed (module-global).
        global total_nodes
        total_nodes += 1

    @cached_property
    def winner(self) -> Player | Draw | None:
        """Outcome here: ``"X"``/``"O"``, ``"-"`` for a draw, ``None`` if ongoing."""
        return get_winner(self.state)

    # TODO: also store the move that led here so it can be used as a "play"
    # for human interaction.
    @cached_property
    def next(self) -> tuple[Node, ...]:
        """Child nodes, one per legal move; empty once the game is decided."""
        if self.winner is not None:
            return ()
        return tuple(
            Node(state=apply_move(self.state, move), depth=self.depth + 1)
            for move in get_valid_moves(self.state)
        )

    # TODO: max depth
    def score(self, target: Player) -> float:
        """Minimax value of this node from *target*'s point of view.

        Returns ``inf`` if *target* can force a win from here, ``-inf`` if the
        opponent can force a win, and ``0.0`` if best play by both sides ends
        in a draw.  The whole subtree below this node is searched.
        """
        winner = self.winner
        if winner == target:
            return float("inf")
        if winner == "-":
            return 0.0
        if winner is not None:
            return float("-inf")

        # Undecided means at least one empty square, so next is non-empty and
        # max()/min() below always have something to compare.
        child_scores = (child.score(target) for child in self.next)
        # Alternate by turn: the player to move here picks their best child.
        if self.state.player == target:
            return max(child_scores)
        return min(child_scores)
|
|
|
|
|
|
|
|
|
2023-06-01 08:05:25 +00:00
|
|
|
def count_nodes(root: Node) -> int:
    """Count *root* plus every node reachable beneath it via ``.next``.

    Walks the tree with an explicit stack in the same pre-order a recursive
    traversal would use, so children are expanded in the same sequence.  A
    node reachable along several paths is counted once per path.
    """
    count = 0
    stack = [root]
    while stack:
        node = stack.pop()
        count += 1
        # Push children reversed so the first child is popped (visited) first.
        stack.extend(reversed(node.next))
    return count
|
|
|
|
|
|
|
|
|
2023-06-01 07:58:45 +00:00
|
|
|
if __name__ == "__main__":
    root = Node()
    # Only the root exists at this point; score() expands the whole tree.
    print(f"{total_nodes=}")
    print(f"{root=}, {root.score('X')=}")
    print(f"{count_nodes(root)=}")
    print(f"{total_nodes=}")

    # NOTE: the total printed here (549946) matches the well-known size of the
    # complete tic-tac-toe game *tree*.  Node.next creates a fresh Node for
    # every move, so a board reached by several move orders is counted once
    # per order.  The linked post instead counts distinct *boards*, which is a
    # much smaller number:
    # https://stackoverflow.com/questions/7466429/generate-a-list-of-all-unique-tic-tac-toe-boards
    # To match it, de-duplicate already-visited states, e.g. share one Node per
    # GameState via a lookup table (presumably safe, since depth is implied by
    # the number of marks on the board -- verify before relying on it).
|