2023-06-01 07:58:45 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
import typing as tp
|
|
|
|
from dataclasses import dataclass
|
2023-06-01 21:22:42 +00:00
|
|
|
from functools import cache
|
2023-06-01 07:58:45 +00:00
|
|
|
|
|
|
|
Player = tp.Literal["X", "O"]
|
2023-06-01 21:22:42 +00:00
|
|
|
Winner = Player | tp.Literal["-"]
|
|
|
|
Square = Player | tp.Literal[" "]
|
2023-06-01 07:58:45 +00:00
|
|
|
Board = tuple[Square, Square, Square, Square, Square, Square, Square, Square, Square]
|
2023-06-01 21:22:42 +00:00
|
|
|
# TODO: make this nicer for printing
|
|
|
|
Score = int
|
2023-06-01 07:58:45 +00:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class GameState:
    """An immutable tic-tac-toe position.

    The dataclass is frozen (and therefore hashable), so states can be passed
    to the functools.cache-decorated search functions as cache keys.
    """

    player: Player  # side to move; its mark is the one the next move places
    board: Board  # nine squares, row-major (0 = top-left, 8 = bottom-right)

    def __str__(self):
        # Three rows of " a │ b │ c", separated by box-drawing rules.
        cells = self.board
        rows = [
            f" {cells[start]} │ {cells[start + 1]} │ {cells[start + 2]}"
            for start in range(0, 9, 3)
        ]
        return "\n───┼───┼───\n".join(rows)
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class Move:
    """Placement of the current player's mark on one square.

    ``position`` indexes the nine-square board in row-major order
    (0 = top-left, 8 = bottom-right).

    Raises:
        ValueError: if ``position`` is outside 0..8.
    """

    position: int

    def __post_init__(self):
        # Reject out-of-range squares at construction: a negative index would
        # otherwise silently wrap around when the board tuple is indexed.
        if not 0 <= self.position < 9:
            raise ValueError(f"position must be in 0..8, got {self.position!r}")
|
|
|
|
|
|
|
|
|
2023-06-01 08:24:45 +00:00
|
|
|
@cache
def get_valid_moves(state: GameState) -> tuple[Move, ...]:
    """Return one Move for every empty (" ") square of ``state.board``.

    Moves are produced in ascending square order. The return type is
    ``tuple[Move, ...]`` (any length from 0 to 9); the previous annotation
    ``tuple[Move]`` denoted a tuple of exactly one Move.
    """
    return tuple(
        Move(position=index)
        for index, square in enumerate(state.board)
        if square == " "
    )
|
|
|
|
|
|
|
|
|
|
|
|
@cache
def apply_move(state: GameState, move: Move) -> GameState:
    """Return the state reached when ``state.player`` marks ``move.position``.

    The mover's mark is written into a copy of the board, and the turn passes
    to the other side. ``state`` itself is never modified.

    Raises:
        ValueError: if ``move.position`` is not an index of the board, or the
            square there is already occupied. Without this check, an occupied
            square would simply be overwritten and a negative index would
            silently wrap to the far end of the board.
    """
    board = state.board
    pos = move.position
    if not 0 <= pos < len(board) or board[pos] != " ":
        raise ValueError(f"illegal move {pos!r} on board {board!r}")

    # "O" moves next after "X", and vice versa.
    new_player: Player = "X" if state.player == "O" else "O"
    # Tuple slicing keeps the board immutable (and hashable for @cache).
    new_board = board[:pos] + (state.player,) + board[pos + 1:]
    return GameState(player=new_player, board=new_board)
|
|
|
|
|
|
|
|
|
2023-06-01 08:17:35 +00:00
|
|
|
@cache
def get_winner(state: GameState) -> Winner | None:
    """Classify ``state``: the winning mark, "-" for a draw, or None if still open.

    Rows are checked top to bottom, then columns left to right, then the
    main and anti diagonals. The first line of three identical non-blank
    marks decides the result. Only when no line is complete does a board
    with no " " squares count as a draw.
    """
    # Naming all nine squares also rejects boards of any other length
    # (ValueError from the unpack).
    a, b, c, d, e, f, g, h, i = state.board
    squares = (a, b, c, d, e, f, g, h, i)

    # Index triples for every line of three, in the order they are checked.
    lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )
    for first, second, third in lines:
        mark = squares[first]
        if mark != " " and mark == squares[second] == squares[third]:
            return mark

    # No completed line: a full board is a draw; otherwise play continues.
    return None if " " in squares else "-"
|
|
|
|
|
|
|
|
|
|
|
|
# exhaustive minimax tree search (every reachable position is scored; no random sampling)
|
|
|
|
|
|
|
|
# Opening position: all nine squares empty, "X" to move.
START = GameState(player="X", board=tuple(" " for _ in range(9)))

# Every search function is wrapped in functools.cache, so repeated positions
# are not searched again. get_score increments this counter on each cache
# miss, i.e. once per distinct (state, target) pair it evaluates.
total_nodes = 0
|
2023-06-01 07:58:45 +00:00
|
|
|
|
|
|
|
|
2023-06-01 08:19:52 +00:00
|
|
|
@cache
def get_next_states(state: GameState) -> tuple[GameState, ...]:
    """Return every state reachable from ``state`` with one move.

    A finished game (won or drawn, according to get_winner) has no
    successors, so an empty tuple is returned for it. Otherwise successors
    follow the ascending-square order of get_valid_moves.
    """
    if get_winner(state) is None:
        successors = [apply_move(state, move) for move in get_valid_moves(state)]
        return tuple(successors)
    return ()
|
2023-06-01 07:58:45 +00:00
|
|
|
|
|
|
|
|
2023-06-01 08:19:52 +00:00
|
|
|
@cache
|
2023-06-01 21:22:42 +00:00
|
|
|
def get_score(state: GameState, target: Player) -> Score:
|
2023-06-01 08:19:52 +00:00
|
|
|
global total_nodes
|
|
|
|
total_nodes += 1
|
2023-06-01 08:22:37 +00:00
|
|
|
|
2023-06-01 08:19:52 +00:00
|
|
|
winner = get_winner(state)
|
|
|
|
if winner == target:
|
2023-06-01 21:22:42 +00:00
|
|
|
return 1
|
2023-06-01 08:19:52 +00:00
|
|
|
if winner == "-":
|
|
|
|
return 0
|
2023-06-01 08:22:37 +00:00
|
|
|
if winner is not None:
|
|
|
|
# winner must be the opponent
|
2023-06-01 21:22:42 +00:00
|
|
|
return -1
|
2023-06-01 08:19:52 +00:00
|
|
|
|
|
|
|
next_scores = (
|
|
|
|
get_score(next_state, target) for next_state in get_next_states(state)
|
|
|
|
)
|
|
|
|
if state.player == target:
|
|
|
|
# target player plays at this node
|
|
|
|
return max(next_scores)
|
|
|
|
else:
|
|
|
|
# opponent plays at this node
|
|
|
|
return min(next_scores)
|
2023-06-01 08:05:25 +00:00
|
|
|
|
|
|
|
|
2023-06-01 07:58:45 +00:00
|
|
|
if __name__ == "__main__":
    # TODO: "get best move"
    # TODO: exploit rotational / reflectional board symmetry (to visit fewer nodes)

    # Both f-strings use the "=" specifier, so each printed label is the
    # expression text itself. The arguments are evaluated left to right, so
    # total_nodes is read only after get_score has finished the search.
    print(
        f"{START=}, {get_score(START, 'X')=}",
        f"{total_nodes=}",  # the author's recorded run printed 5478
        sep="\n",
    )
|