Chapter 5.2: AlphaZero

One of the most impressive machine learning systems of recent years was Google DeepMind's AlphaGo, which beat the world's top Go player. In this tutorial we will go over the basics of AlphaZero (similar to AlphaGo, but trained entirely through self-play rather than on past human matches) and implement it for a much simpler game: Connect Four. AlphaZero is essentially an extension of the MCTS algorithm we covered in the previous article. The only change to the tree search itself is that we use a different function from UCT to decide how to traverse the tree. This function is called PUCT (predictor + upper confidence bound applied to trees).

\(PUCT(s,a)=\bar{X}(s,a)+c_{\text{puct}}P(s,a)\frac{\sqrt{N(s)}}{1+N(s,a)}\)

The only new term in this formula is P(s,a), the prior probability of taking action a in state s, given by a neural network. Guiding the search with these priors reduces the number of simulations we need for strong play and improves on plain UCT. Below we implement the algorithm. Note that it is very similar to our previous MCTS, but it handles tensors (for our PyTorch model) and uses PUCT instead of UCT.
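
As a quick numerical illustration before the full implementation, the toy function below (not part of the final code) scores two hypothetical moves: \(\bar{X}\) is the running mean value (stored as Q in the search code that follows), and the unvisited move wins the comparison purely because of its prior and the exploration term.

import math

def puct_score(q, p, parent_visits, child_visits, c_puct=1.0):
    """PUCT(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))."""
    return q + c_puct * p * math.sqrt(parent_visits) / (1 + child_visits)

# Toy numbers: after 25 visits to the parent, a promising move that has already
# been explored competes with an unvisited move that has a smaller prior.
explored = puct_score(q=0.40, p=0.50, parent_visits=25, child_visits=10)  # ~0.63
unvisited = puct_score(q=0.00, p=0.20, parent_visits=25, child_visits=0)  # ~1.00
print(explored, unvisited)  # the unvisited move scores higher -> exploration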


import math

import numpy as np
import torch
import torch.nn.functional as F

# encode_board and COLUMN_COUNT come from the Connect Four game/encoding code
# (see train.py and play.py); they are not defined in this snippet.


class MCTSNode:
    def __init__(self, game, parent=None, move=None):
        self.game = game
        self.parent = parent
        self.move = move
        self.children = {}
        self.N = 0    # Visit count
        self.W = 0.0  # Total value
        self.Q = 0.0  # Mean value
        self.P = 0.0  # Prior probability from the network

    def is_leaf(self):
        return len(self.children) == 0


def masked_policy(logits, legal_moves):
    # The network outputs log-probabilities; since log_softmax only shifts the
    # logits by a constant, masking illegal moves and applying softmax again
    # gives a properly renormalised distribution over the legal moves.
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask[legal_moves] = True
    return F.softmax(logits.masked_fill(~mask, -1e9), dim=0).cpu().numpy()


def expand(node, model):
    # Evaluate the network at this node, create one child per legal move with
    # its prior probability P, and return the value estimate for backpropagation.
    input_tensor = encode_board(node.game.board, node.game.player).unsqueeze(0)
    with torch.no_grad():
        logits, value = model(input_tensor)

    legal_moves = node.game.get_legal_actions()
    probs = masked_policy(logits.squeeze(0), legal_moves)

    for move in legal_moves:
        child = MCTSNode(node.game.move(move), node, move)
        child.P = probs[move]
        node.children[move] = child

    return value.item()


def mcts_search(root_state, model, simulations=100, c_puct=1.0):
    root = MCTSNode(root_state)
    expand(root, model)  # give the root's children their priors from the network

    for _ in range(simulations):
        node = root
        path = [root]

        # Selection: descend with PUCT until we reach a leaf
        while not node.is_leaf():
            best_score, best_child = -float("inf"), None
            for child in node.children.values():
                u = c_puct * child.P * math.sqrt(node.N) / (1 + child.N)
                score = child.Q + u
                if score > best_score:
                    best_score, best_child = score, child
            node = best_child
            path.append(node)

        # Expansion + evaluation: use the network's value at a new leaf, or the
        # true result at a terminal node (both from the perspective of the
        # player to move at that node)
        if not node.game.is_terminal():
            v = expand(node, model)
        else:
            v = node.game.get_result()

        # Backpropagation: flip the sign at every step, since each node's
        # statistics are stored from the perspective of the player who moved into it
        for n in reversed(path):
            v = -v
            n.N += 1
            n.W += v
            n.Q = n.W / n.N

    # Build pi from the root's visit counts
    visit_counts = np.zeros(COLUMN_COUNT)
    for move, child in root.children.items():
        visit_counts[move] = child.N

    pi = visit_counts / (np.sum(visit_counts) + 1e-8)
    best_move = int(np.argmax(visit_counts))

    return best_move, pi
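

To make the interface concrete, here is a hedged usage sketch of a single search call. It assumes a ConnectFour game class with the methods used above (get_legal_actions, move, is_terminal, plus board and player attributes) and the AlphaZeroCNN network defined in the next section; the real game and driver code live in train.py and play.py.

# Hypothetical usage; ConnectFour is an assumed game class with the
# interface used by mcts_search above.
model = AlphaZeroCNN()   # the network defined below
model.eval()

game = ConnectFour()     # fresh, empty board
best_move, pi = mcts_search(game, model, simulations=200, c_puct=1.0)

print("chosen column:", best_move)
print("search policy pi:", pi)   # visit-count distribution over the 7 columns
game = game.move(best_move)      # play the move and keep going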
  

For our network we use a typical CNN architecture with an input of shape [batch_size, 2, rows, columns], where the two channels hold the positions of player 1's pieces and player 2's pieces, and the board is 6 rows by 7 columns for Connect Four. The network has two outputs. The policy head is a probability distribution over the possible next moves; here we use log_softmax to convert the raw outputs into (log) probabilities. The value head is a scalar between -1 and 1 that estimates the expected outcome of the current state (1 = certain win, 0 = draw, -1 = certain loss); a tanh constrains the output to that range.
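
The encode_board helper used by the tree search above is assumed to produce exactly this two-plane tensor. Below is a minimal sketch of one possible implementation; it assumes the board is a 6x7 numpy array with 0 for empty cells and 1 / 2 for the two players' pieces, and it orders the planes as current player then opponent so the network always sees the position from the side to move. Your actual encode_board (in train.py) may use a different board representation or the fixed player 1 / player 2 ordering described above.

import numpy as np
import torch

def encode_board(board, player):
    # Sketch only. Assumes `board` is a 6x7 numpy array with 0 = empty,
    # 1 = player 1's pieces, 2 = player 2's pieces, and `player` is the
    # player to move. Plane 0 = pieces of the player to move, plane 1 =
    # opponent's pieces (one possible convention; adjust to your game code).
    opponent = 2 if player == 1 else 1
    planes = np.stack([
        (board == player).astype(np.float32),
        (board == opponent).astype(np.float32),
    ])
    return torch.from_numpy(planes)   # shape [2, 6, 7]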

One final important thing to consider is illegal moves: for example, once a column in Connect Four is full you cannot play another piece in it. To handle this we create a mask vector M over the columns, where \(M_{i}\) is 1 if column i is legal and 0 if it is illegal. We apply the mask to the output of the network and renormalise to get a valid distribution over the legal moves. We only apply this mask at evaluation time, not during training, since during training the network naturally learns not to choose illegal moves. We implement a ResNet in the code below.


import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return F.relu(out)

class AlphaZeroCNN(nn.Module):
    def __init__(self, board_height=6, board_width=7, action_size=7, num_res_blocks=5):
        super().__init__()
        self.board_height = board_height
        self.board_width = board_width
        self.action_size = action_size

        # Initial convolution
        self.conv = nn.Conv2d(2, 64, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(64)

        # Residual tower
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_res_blocks)]
        )

        # --- Policy Head ---
        self.policy_conv = nn.Conv2d(64, 2, kernel_size=1)
        self.policy_bn = nn.BatchNorm2d(2)
        self.policy_fc = nn.Linear(2 * board_height * board_width, action_size)

        # --- Value Head ---
        self.value_conv = nn.Conv2d(64, 1, kernel_size=1)
        self.value_bn = nn.BatchNorm2d(1)
        self.value_fc1 = nn.Linear(board_height * board_width, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x, legal_moves_mask=None):
        # Shared body
        x = F.relu(self.bn(self.conv(x)))       # [B, 64, 6, 7]
        x = self.res_blocks(x)                  # [B, 64, 6, 7]

        # --- Policy Head ---
        p = F.relu(self.policy_bn(self.policy_conv(x)))  # [B, 2, 6, 7]
        p = p.view(p.size(0), -1)                        # [B, 2*6*7]
        p = self.policy_fc(p)                            # [B, 7] (logits)

        if legal_moves_mask is not None:
            p = p.masked_fill(~legal_moves_mask, float('-1e9'))

        log_probs = F.log_softmax(p, dim=1)              # [B, 7]

        # --- Value Head ---
        v = F.relu(self.value_bn(self.value_conv(x)))    # [B, 1, 6, 7]
        v = v.view(v.size(0), -1)                        # [B, 6*7]
        v = F.relu(self.value_fc1(v))                    # [B, 64]
        v = torch.tanh(self.value_fc2(v))                # [B, 1]

        return log_probs, v.squeeze(1)
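

As a quick sanity check of the output shapes and of the masking described earlier, here is a hedged usage sketch. The dummy input and the mask (pretending column 3 is full) are made up purely for illustration.

import torch

# Sketch: run the network on a dummy position with one column masked out.
model = AlphaZeroCNN()
model.eval()

x = torch.zeros(1, 2, 6, 7)                      # one empty board, two planes
legal = torch.tensor([[True, True, True, False,  # pretend column 3 is full
                       True, True, True]])

with torch.no_grad():
    log_probs, value = model(x, legal_moves_mask=legal)

probs = log_probs.exp()                          # back to probabilities
print(probs.shape, value.shape)                  # torch.Size([1, 7]) torch.Size([1])
print(probs[0, 3].item())                        # ~0: the masked column gets no mass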
  

The loss function we use to train the network is:

\(\mathcal{L}=\mathcal{L}_{\text{value}}+\mathcal{L}_{\text{policy}}=(v-z)^{2}-\sum_{i}\pi_{i}\log p_{i}\)

Here v is the predicted value from the network (between -1 and 1), z is the actual game outcome (-1, 0 or 1), \(\pi\) is the target policy from the Monte Carlo tree search, and p is the predicted policy from the network. Since our network already outputs log-probabilities (via log_softmax), the code below passes them straight into the cross-entropy term without taking another log.


def alpha_zero_loss(pred_policy, pred_value, target_policy, target_value):
    # pred_policy is the network's log-probabilities [B, 7];
    # target_policy is the MCTS visit distribution pi [B, 7]
    policy_loss = -torch.sum(target_policy * pred_policy, dim=1).mean()
    value_loss = F.mse_loss(pred_value, target_value)
    return policy_loss + value_loss
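

Putting the network and the loss together, one optimisation step over a batch of self-play data might look like the sketch below. The batch tensors (states, pis, zs), the choice of optimiser and the learning rate are assumptions for illustration; the actual training loop is in train.py.

import torch.optim as optim

# Sketch of one training step, assuming:
#   states: float tensor [B, 2, 6, 7] of encoded boards
#   pis:    float tensor [B, 7] of MCTS target policies
#   zs:     float tensor [B] of game outcomes in {-1, 0, 1}
model = AlphaZeroCNN()
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # placeholder hyperparameters

def train_step(states, pis, zs):
    model.train()
    optimizer.zero_grad()
    log_probs, values = model(states)          # no legal-move mask during training
    loss = alpha_zero_loss(log_probs, values, pis, zs)
    loss.backward()
    optimizer.step()
    return loss.item()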
  

The final piece of the AlphaZero algorithm is that the network is trained through self-play. Both players in a game are controlled by the current network, and we play through a series of games (typically 100-1000), choosing every move with MCTS using PUCT. For every move in every game we store three variables \((s_{t},\pi_{t},z)\). \(s_{t}\) is the board state. \(\pi_{t}\) is the target policy derived from the visit counts, \(\pi_{t}(a)=\frac{N(s_{t},a)^{\alpha}}{\sum_{b}N(s_{t},b)^{\alpha}}\), a probability distribution over the possible actions at that state; here \(\alpha\) acts as an inverse temperature that controls exploration (larger \(\alpha\) concentrates the policy on the most-visited moves). z records, once the game is over, whether the player to move at \(s_{t}\) won, drew or lost (1, 0 or -1). At the end of these games we have data we can use to train the network to play better, and we repeat the whole process until the network is sufficiently trained. You can see this implemented in train.py and play.py; a sketch of the self-play data collection is shown below.
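
Below is a hedged sketch of that self-play loop for a single game, reusing the (assumed) ConnectFour class and encode_board from earlier; it also assumes game.get_result() reports the outcome from the perspective of the player to move at the terminal position, the same convention the tree search above relies on. The full version lives in train.py.

def self_play_game(model, simulations=100):
    # Play one game of the current network against itself and return
    # a list of (encoded state, pi, z) training examples.
    game = ConnectFour()          # assumed game class, as before
    history = []                  # (encoded state, pi, player to move)

    while not game.is_terminal():
        move, pi = mcts_search(game, model, simulations=simulations)
        history.append((encode_board(game.board, game.player), pi, game.player))
        game = game.move(move)

    # Assumed convention: the result is from the perspective of the player
    # to move at the terminal position (usually the loser, or 0 for a draw).
    result = game.get_result()

    data = []
    for state, pi, player in history:
        # z = +1 if the player to move at this state went on to win,
        # -1 if they lost, 0 for a draw.
        z = result if player == game.player else -result
        data.append((state, pi, z))
    return data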

After all this is done we will have an intelligent agent we can play against in Connect Four. (For anyone reading this now: the network is still training, so I will post the rest of this tutorial in a few days.)