games/tictactoe.py
2023-06-01 01:24:45 -07:00

136 lines
3.0 KiB
Python

from __future__ import annotations
import typing as tp
from dataclasses import dataclass
from functools import cache, cached_property
# --- Type aliases shared by the whole module -------------------------------
# A side in the game: the mark that moves, or the mark that wins.
Player = tp.Literal["X", "O"]
# Outcome marker returned when the board fills up with no completed line.
Draw = tp.Literal["-"]
# Contents of a single cell; " " marks an empty square.
Square = tp.Literal["X", "O", " "]
# A complete board: exactly nine squares, row-major (indices 0..8).
Board = tuple[(Square,) * 9]
@dataclass(frozen=True)
class GameState:
    """Immutable snapshot of a tic-tac-toe position.

    Being a frozen dataclass it is hashable, which lets states be passed to
    the ``functools.cache``-decorated search functions in this module.
    """

    # The mark whose turn it is to move in this position.
    player: Player
    # Nine squares in row-major order: 0-2 top row, 3-5 middle, 6-8 bottom.
    board: Board

    def __str__(self) -> str:
        """Render the board as a 3x3 grid drawn with box characters."""
        b = self.board
        # Each mark sits centred in a three-column slot and slots are joined
        # by "│", so the rows line up with the "───┼───┼───" separators.
        return (
            f" {b[0]} │ {b[1]} │ {b[2]}\n"
            "───┼───┼───\n"
            f" {b[3]} │ {b[4]} │ {b[5]}\n"
            "───┼───┼───\n"
            f" {b[6]} │ {b[7]} │ {b[8]}"
        )
@dataclass(eq=True, frozen=True)
class Move:
    """One placement: ``player`` puts their mark on square ``position``.

    Frozen so instances are hashable and usable as cached-function arguments.
    """

    # The mark being placed.
    player: Player
    # Index into the row-major board tuple of the square being filled.
    position: int
@cache
def get_valid_moves(state: GameState) -> tuple[Move, ...]:
    """Return every legal move for ``state.player``, in board-index order.

    One ``Move`` is produced per empty (" ") square. This function does not
    check whether the game is already won; callers that care must check
    ``get_winner`` themselves.
    """
    return tuple(
        Move(state.player, position=i)
        for i, square in enumerate(state.board)
        if square == " "
    )
@cache
def apply_move(state: GameState, move: Move) -> GameState:
    """Return the position reached after ``move`` is played on ``state``.

    The mark ``move.player`` is written to square ``move.position`` and the
    turn passes to the other mark. ``state`` itself is left untouched.

    Raises:
        IndexError: if ``move.position`` does not index a square of the board.
        ValueError: if the target square is already occupied.
    """
    board = state.board
    # Check the range explicitly: a negative index would otherwise silently
    # wrap around and overwrite a square at the end of the board.
    if not 0 <= move.position < len(board):
        raise IndexError(f"move position {move.position} is off the board")
    if board[move.position] != " ":
        raise ValueError(
            f"square {move.position} is already occupied by {board[move.position]!r}"
        )
    new_player = "X" if move.player == "O" else "O"
    squares = list(board)
    squares[move.position] = move.player
    return GameState(player=new_player, board=tuple(squares))
@cache
def get_winner(state: GameState) -> Player | Draw | None:
    """Classify the board as a win, a draw, or a game still in progress.

    Returns the mark owning the first completed line found (rows top to
    bottom, then columns left to right, then the two diagonals). Returns "-"
    if no line is complete and no square is empty, and None otherwise.
    """
    # Board layout (row-major):
    #   a b c
    #   d e f
    #   g h i
    a, b, c, d, e, f, g, h, i = state.board
    lines = (
        (a, b, c), (d, e, f), (g, h, i),  # rows
        (a, d, g), (b, e, h), (c, f, i),  # columns
        (a, e, i), (c, e, g),             # diagonals
    )
    for first, second, third in lines:
        if first != " " and first == second == third:
            return first
    if " " in state.board:
        return None  # empty squares remain: game not finished
    return "-"  # board full without a line: draw
# Exhaustive minimax search over the full game tree. Positions are
# memoised with functools.cache, so each one is only expanded once.
# Empty board with X to move.
START = GameState(player="X", board=tuple(" " for _ in range(9)))
# Incremented inside get_score's body, i.e. once per cache miss.
total_nodes = 0
@cache
def get_next_states(state: GameState) -> tuple[GameState, ...]:
    """Return every position reachable in one move from ``state``.

    Finished games (won or drawn, per ``get_winner``) have no successors.
    """
    if get_winner(state) is None:
        return tuple(apply_move(state, mv) for mv in get_valid_moves(state))
    return ()
@cache
def get_score(state: GameState, target: Player) -> float:
    """Minimax value of ``state`` from ``target``'s point of view.

    Returns ``inf`` when ``target`` has won or can force a win. Returns
    ``-inf`` when the opponent has won or can force a win. Returns ``0`` for
    a draw under best play. Wins at any depth score the same. Every
    evaluation that misses the cache increments the module-level
    ``total_nodes`` counter.
    """
    global total_nodes
    total_nodes += 1

    outcome = get_winner(state)
    if outcome == target:
        return float("inf")
    if outcome == "-":
        return 0
    if outcome is None:
        # Game still running: the side to move picks its best continuation,
        # maximising when it is target's turn and minimising otherwise.
        pick = max if state.player == target else min
        return pick(get_score(child, target) for child in get_next_states(state))
    # Any remaining outcome is a win for the opponent.
    return float("-inf")
if __name__ == "__main__":
    # TODO: also compute and print the best move, not only the score.
    report = f"{START=}, {get_score(START, 'X')=}"
    print(report)
    print(f"{total_nodes=}")  # observed: 5478