Chapter 5.1: Monte Carlo Tree Search

This will be a preliminary tutorial for the next one on AlphaZero, which relies heavily on a modified version of Monte Carlo Tree Search (MCTS) for board games. MCTS builds a search tree incrementally using random simulation (hence "Monte Carlo"). These random simulations are called rollouts. In each rollout the game is played through to the end by selecting random moves (other heuristics are possible, but we will use random moves). As the search explores the tree, it keeps track of a set of statistics for every node it discovers. These statistics guide the decision at the current node, so that we can choose the move most likely to lead to a winning game.
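To see why random rollouts are useful at all, here is a small sketch of plain Monte Carlo evaluation (no tree yet): it estimates how good a position is by averaging the results of many random playouts. The game is a hypothetical toy chosen only for illustration (a pile of stones, players alternately take 1, 2 or 3, and whoever takes the last stone wins), and random_rollout and monte_carlo_value are illustrative names, not part of the code later in this chapter.

```python
import random

# Hypothetical toy game for illustration: a pile of stones, players
# alternately take 1, 2 or 3 stones, and whoever takes the last stone wins.


def random_rollout(stones, rng):
    """Play uniformly random moves until the pile is empty.

    Returns True if the player to move at the start took the last stone.
    Assumes stones >= 1.
    """
    starter_to_move = True
    while True:
        stones -= rng.choice([t for t in (1, 2, 3) if t <= stones])
        if stones == 0:
            return starter_to_move
        starter_to_move = not starter_to_move


def monte_carlo_value(stones, n_rollouts=2000, seed=0):
    """Average reward (+1 win, -1 loss) for the player to move over random rollouts."""
    rng = random.Random(seed)
    total = 0
    for _ in range(n_rollouts):
        total += 1 if random_rollout(stones, rng) else -1
    return total / n_rollouts


for stones in (1, 2, 4):
    print(stones, monte_carlo_value(stones))
```

Under uniformly random play the exact values for 1, 2 and 4 stones are 1, 0 and -1/3, so the estimates should land close to those. MCTS builds on this idea by storing the results in a tree and using them to steer later rollouts toward promising moves.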

For each node we explore during random sampling we keep track of several quantities. The first is \(N(s)\), the number of visits to node \(s\) (how many times node \(s\) has been encountered during sampling). The second is \(N(s,a)\), the number of times action \(a\) was taken at node \(s\); in Connect Four, for example, it could be how many times we chose to drop a piece in column \(a\). We also have \(Q(s,a)\), the total reward from simulations in which action \(a\) was taken at node \(s\). The rewards come from the leaf nodes where a game ends (player win, A.I. win or draw) and are 1 for a win, 0 for a draw and -1 for a loss. Every rollout that passes through node \(s\) and takes action \(a\) adds its final reward to \(Q(s,a)\). Finally we track \(\bar{X}(s,a)=Q(s,a)/N(s,a)\), the average reward of action \(a\). In code, \(N(s,a)\) and \(Q(s,a)\) can be stored on the child node that action \(a\) leads to (as its visit count and total reward), so a small class that records these values for each node is all we need.


class MCTSNode:
    def __init__(self, state, parent=None, move=None):
        self.state = state
        self.parent = parent
        self.move = move
        self.children = []
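        self.visits = 0        # visits to this node: N(s), and N(s,a) from the parent's view
        self.total_reward = 0  # Q(s,a): total reward, where a is the move that led here
        self.untried_moves = list(state.get_legal_actions())  # moves not yet expanded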
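To make the bookkeeping concrete, here is a small standalone sketch that stores \(N(s)\), \(N(s,a)\) and \(Q(s,a)\) for a single node and computes \(\bar{X}(s,a)\). The NodeStats class is hypothetical (it is not one of the node classes used in this chapter), and the rollout results fed into it are made up for illustration.

```python
from collections import defaultdict


class NodeStats:
    """Statistics for one node s: N(s), N(s,a), Q(s,a) and the mean X̄(s,a)."""

    def __init__(self):
        self.n_s = 0                    # N(s): visits to this node
        self.n_sa = defaultdict(int)    # N(s,a): times action a was taken here
        self.q_sa = defaultdict(float)  # Q(s,a): total reward after taking a

    def record(self, action, reward):
        """Record one rollout that took `action` here and ended with `reward`."""
        self.n_s += 1
        self.n_sa[action] += 1
        self.q_sa[action] += reward

    def mean(self, action):
        """X̄(s,a) = Q(s,a) / N(s,a); raises ZeroDivisionError if a is untried."""
        return self.q_sa[action] / self.n_sa[action]


# Made-up results: column 3 was tried four times (win, win, draw, loss),
# column 4 once (win).
stats = NodeStats()
for action, reward in [(3, 1), (3, 1), (3, 0), (3, -1), (4, 1)]:
    stats.record(action, reward)

print(stats.n_s, stats.n_sa[3], stats.q_sa[3], stats.mean(3))  # 5 4 1.0 0.25
```

Note that \(N(s)\) is the sum of \(N(s,a)\) over all actions, which is why the tree-based classes can get away with storing a single visit count per node.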
  

Starting from some initial game state, each iteration of the search works as follows. First we traverse the tree using UCT (Upper Confidence bounds applied to Trees); the formula is given below, where \(a^{*}\) is the best move according to the statistics gathered so far at the current node (the current game state). This traversal continues until we reach a node that still has at least one untried action (\(N(s,a)=0\) for some action \(a\)) or a game-over state. We then add one new child node, chosen at random from the untried moves. Next we play out a random game from that node (each player takes random moves until the game ends, i.e. a rollout). Finally we propagate the result back up the path from the new node to the root, updating the visit counts and total rewards of every node on that path; the positions visited during the rollout itself are not added to the tree. Note that when we start at the root of the tree, it helps to do some initial exploration by running one rollout from each child node. This gives the UCT formula a starting value for every action, since UCT is undefined for any action with \(N(s,a)=0\) (division by zero).

\(a^{*}=\arg\max_{a}\left[\bar{X}(s,a)+C\cdot\sqrt{\frac{\ln N(s)}{N(s,a)}}\right]\)

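Here \(C\) is the exploration parameter, which controls the balance between exploration and exploitation. The standard theoretical value is \(\sqrt{2}\) (from the UCB1 analysis, which assumes rewards in \([0,1]\)), but in practice \(C\) is usually chosen empirically. Note that the formula is a sum of an exploitation term, \(\bar{X}(s,a)\) (what we already know about the move), and an exploration term, \(C\cdot\sqrt{\frac{\ln N(s)}{N(s,a)}}\), which is large for actions that have been tried rarely compared with how often their parent has been visited. So the function makes its decision based on what is known while still encouraging some exploration.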
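The trade-off controlled by \(C\) is easiest to see with numbers. The sketch below scores two actions with the UCT formula; the statistics are made up, and uct_score and pick are illustrative helper names. An untried action gets \(+\infty\), which is one common way to sidestep the division by zero when \(N(s,a)=0\).

```python
import math


def uct_score(total_reward, child_visits, parent_visits, c=math.sqrt(2)):
    """UCT value X̄(s,a) + C * sqrt(ln N(s) / N(s,a)).

    An action that has never been tried gets +inf so it is explored first.
    """
    if child_visits == 0:
        return math.inf
    exploit = total_reward / child_visits                        # X̄(s,a)
    explore = math.sqrt(math.log(parent_visits) / child_visits)  # exploration bonus
    return exploit + c * explore


def pick(children, parent_visits, c):
    """Return the action with the highest UCT score; children maps action -> (Q, N)."""
    return max(children, key=lambda a: uct_score(*children[a], parent_visits, c))


# Made-up statistics at a node with N(s) = 100:
#   action "a": well explored, mean reward 0.6  (Q = 54, N = 90)
#   action "b": rarely tried,  mean reward 0.2  (Q = 2,  N = 10)
children = {"a": (54.0, 90), "b": (2.0, 10)}
print(pick(children, 100, c=0.5))           # a  (exploitation dominates)
print(pick(children, 100, c=math.sqrt(2)))  # b  (exploration bonus dominates)
```

With \(C=0.5\) the scores are roughly 0.71 for "a" and 0.54 for "b", so the better-known move wins; with \(C=\sqrt{2}\) they are roughly 0.92 and 1.16, so the rarely tried move gets another look.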

The algorithm is usually described in four steps: selection, expansion, simulation and backpropagation (not to be confused with backpropagation in neural networks; no gradients are involved here). Selection applies the UCT formula at each successive level of the tree to choose a path, stopping when we reach a node that has a child with \(N(s,a)=0\) (an untried action). In expansion, unless that node is a game-over state, we add a new child node for one of its untried actions. In simulation we perform a rollout from that child node. In backpropagation we update the statistics of every node on the path we took. We now expand our node class with methods to check whether a node has any untried actions left (is_fully_expanded()), to return the child with the highest UCT value (best_child(c=1.4), where 1.4 is close to \(\sqrt{2}\)), to create and return a new child for an untried action (expand()) and to update the reward and visit count of a node during backpropagation (update(self, reward)). We also add a function that performs a random rollout.


import math
import random


class Node:
    def __init__(self, state, parent=None):
        self.state = state
        self.parent = parent
        self.children = {}  # action -> Node
        self.visits = 0
        self.total_reward = 0.0
        # Copy, because expand() pops from this list; the action format depends on the game
        self.untried_actions = list(state.get_legal_actions())

    def is_fully_expanded(self):
        return len(self.untried_actions) == 0

    def best_child(self, c=1.4):
        # Select child with highest UCT
        def uct_value(child):
            if child.visits == 0:
                return float('inf')  # Explore unvisited child first
            exploit = child.total_reward / child.visits
            explore = math.sqrt(math.log(self.visits) / child.visits)
            return exploit + c * explore

        return max(self.children.values(), key=uct_value)

    def expand(self):
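        # Pick an untried action at random (a bare pop() would always take the last one)
        action = self.untried_actions.pop(random.randrange(len(self.untried_actions)))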
        next_state = self.state.move(action)
        child_node = Node(next_state, parent=self)
        self.children[action] = child_node
        return child_node

    def update(self, reward):
        self.visits += 1
        self.total_reward += reward


def rollout_policy(state):
    # Default: random rollout
    while not state.is_terminal():
        action = random.choice(state.get_legal_actions())
        state = state.move(action)
    return state.get_result()  # e.g., -1, 0, 1
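The Node class and rollout_policy only rely on four state methods: get_legal_actions(), move(action), is_terminal() and get_result(). As a sketch, here is a hypothetical toy game (a simple Nim variant, not the Connect Four code from the downloads) that implements that interface, so the code above can be tried without a full board game. One design point matters: move() returns a new state instead of modifying the current one. Each Node keeps a reference to its own state, and rollout_policy calls move() repeatedly, so a mutating move() would corrupt the states stored in the tree.

```python
import random


class NimState:
    """Hypothetical toy game exposing the interface Node and rollout_policy expect.

    A pile of stones; players 0 and 1 alternately take 1-3 stones, and whoever
    takes the last stone wins. get_result() scores the finished game from
    player 0's point of view (+1 win, -1 loss; this game has no draws).
    """

    def __init__(self, stones, player=0, winner=None):
        self.stones = stones
        self.player = player   # player to move
        self.winner = winner   # set once the pile is empty

    def get_legal_actions(self):
        return [t for t in (1, 2, 3) if t <= self.stones]

    def move(self, action):
        # Return a NEW state so that states stored in tree nodes never change.
        remaining = self.stones - action
        winner = self.player if remaining == 0 else None
        return NimState(remaining, 1 - self.player, winner)

    def is_terminal(self):
        return self.stones == 0

    def get_result(self):
        return 1 if self.winner == 0 else -1


def rollout_policy(state):
    # Same random rollout as above, repeated so this block runs on its own.
    while not state.is_terminal():
        state = state.move(random.choice(state.get_legal_actions()))
    return state.get_result()


random.seed(0)
results = [rollout_policy(NimState(10)) for _ in range(1000)]
print(sum(results) / len(results))  # average result for player 0 under random play
```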

  

Our final function implements the four steps discussed earlier using the functions defined above. Before the main loop it also performs an initial exploration: it expands every action at the root once and runs a rollout from each resulting child. The iterations parameter sets how many samples the search takes. More iterations generally give stronger play, because the statistics describe the game tree more accurately, at the cost of more computation time.


def mcts(root_state, iterations=1000):
    root = Node(root_state)

    # --- Initial exploration of all actions at the root ---
    while root.untried_actions:
        node = root.expand()
        reward = rollout_policy(node.state)
        while node is not None:
            node.update(reward)
            node = node.parent
    # -------------------------------------------------------

    for _ in range(iterations):
        node = root

        # 1. Selection
        while node.is_fully_expanded() and not node.state.is_terminal():
            node = node.best_child()

        # 2. Expansion
        if not node.state.is_terminal():
            node = node.expand()

        # 3. Simulation (Rollout)
        reward = rollout_policy(node.state)

        # 4. Backpropagation
        while node is not None:
            node.update(reward)
            node = node.parent

    # Final action selection (e.g., most visited child)
    return max(root.children.items(), key=lambda item: item[1].visits)[0]
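One caveat for two-player games: the code above adds the same reward to every node on the path. If get_result() reports the outcome for one fixed player, UCT at the opponent's nodes also favours moves that are good for that player, so the search effectively assumes the opponent is trying to help it. A common fix is to score each node from the point of view of the player who made the move into that node. The sketch below is a self-contained variant that does this; the Nim toy game and its winner() method are hypothetical stand-ins for a real game's state and get_result(). Whether the downloadable connectfour.py already compensates for this through its state-conversion function is not shown here.

```python
import math
import random


class Nim:
    """Hypothetical toy game: take 1-3 stones, taking the last stone wins."""

    def __init__(self, stones, player=0):
        self.stones, self.player = stones, player  # player = side to move

    def get_legal_actions(self):
        return [t for t in (1, 2, 3) if t <= self.stones]

    def move(self, action):
        return Nim(self.stones - action, 1 - self.player)

    def is_terminal(self):
        return self.stones == 0

    def winner(self):
        # Whoever took the last stone is the player NOT to move now.
        return 1 - self.player


class Node:
    def __init__(self, state, parent=None):
        self.state, self.parent = state, parent
        self.children = {}  # action -> Node
        self.visits, self.total_reward = 0, 0.0
        self.untried_actions = list(state.get_legal_actions())
        self.mover = 1 - state.player  # player who made the move into this node

    def best_child(self, c=1.4):
        # Each child's reward is from the point of view of the player choosing here.
        def uct(child):
            return (child.total_reward / child.visits
                    + c * math.sqrt(math.log(self.visits) / child.visits))
        return max(self.children.values(), key=uct)

    def expand(self):
        action = self.untried_actions.pop(random.randrange(len(self.untried_actions)))
        child = Node(self.state.move(action), parent=self)
        self.children[action] = child
        return child


def rollout_winner(state):
    while not state.is_terminal():
        state = state.move(random.choice(state.get_legal_actions()))
    return state.winner()


def mcts(root_state, iterations=2000):
    """Return the most visited root action; assumes root_state is not terminal."""
    root = Node(root_state)
    for _ in range(iterations):
        node = root
        while not node.untried_actions and node.children:  # 1. selection
            node = node.best_child()
        if node.untried_actions:                            # 2. expansion
            node = node.expand()
        winner = rollout_winner(node.state)                 # 3. simulation
        while node is not None:                             # 4. backpropagation
            node.visits += 1
            node.total_reward += 1 if winner == node.mover else -1
            node = node.parent
    return max(root.children.items(), key=lambda kv: kv[1].visits)[0]


random.seed(0)
print(mcts(Nim(5)))  # the move that leaves a multiple of 4 stones is the winning one
```

In this toy game, leaving the opponent a multiple of 4 stones wins, so from 5 stones the search should settle on taking 1. A game with draws would add 0 for a draw in the backpropagation step.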
  

Below I link some files you can use to play against MCTS in Connect Four on your computer. I won't go through the Connect Four implementation, but note that connectfour.py contains an important function that converts the game state into a form the MCTS we just implemented can use. To run the script, put mcts.py and connectfour.py in the same folder, open a terminal in that folder and run the command python3 connectfour.py. As a final piece of advice, set the iterations to at least 20000 if you want to play against a competitive agent.

🔽 Download MCTS Files

Click below to download the Python files needed to play against the MCTS agent in Connect Four.

Download mcts.py Download connectfour.py