BART’s tree structure implementation
As we mentioned in the previous post, BART is a sum-of-trees model in which each tree is a decision tree. Thus, to implement BART we first have to implement the tree structure it uses. The tree implementation needs:
- A data structure to link the nodes. It should allow random access to any node in the tree and make it easy to add and delete nodes.
- Two types of nodes that make up the tree: splitting nodes and leaf nodes. The splitting nodes divide the predictor space and contain the logic to traverse the tree given an element $x$; the leaf nodes hold the responses $\mu_{ij}$ of the tree.
- Functions to grow and prune the tree.
- Checks of the tree's correctness.
Data structure to link the nodes
Two types of data structures were considered:
- A series of linked nodes, where each node has a link to its left and right children, if they exist. The root node represents the whole tree, since the whole tree can be traversed from it.
- A dictionary that stores the nodes in breadth-first order, based on the array representation of binary trees.
We started by coding the linked-node structure, inspired by this implementation of a binary tree. This structure made creating and deleting nodes easy, since we only needed to replace links between nodes. But early on we realized that it would not allow us to randomly access a node in the tree, so we dropped it, keeping only the code that represents the tree as a string.
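For comparison, here is a minimal sketch of the linked-node idea. It was written for this post rather than taken from the PyMC branch, and `LinkedNode`, `to_string` and `find` are hypothetical names. Dropping a subtree is a single assignment, but reaching a particular node requires searching down from the root, which is the limitation described above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LinkedNode:
    """Hypothetical linked node: each node points to its children, if they exist."""
    label: str
    left: Optional["LinkedNode"] = None
    right: Optional["LinkedNode"] = None


def to_string(node: Optional[LinkedNode], depth: int = 0) -> str:
    """Render the subtree rooted at `node`, one node per line, indented by depth."""
    if node is None:
        return ""
    lines = ["  " * depth + node.label]
    for child in (node.left, node.right):
        text = to_string(child, depth + 1)
        if text:
            lines.append(text)
    return "\n".join(lines)


def find(node: Optional[LinkedNode], label: str) -> Optional[LinkedNode]:
    """No random access: locating a node means searching from the root."""
    if node is None or node.label == label:
        return node
    return find(node.left, label) or find(node.right, label)


root = LinkedNode("x0 <= 0.5", LinkedNode("mu = -1.0"), LinkedNode("mu = 1.0"))
print(to_string(root))

# Deleting a subtree only requires replacing a link.
root.right = None
```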
Therefore, we looked for a structure that explicitly represents the nodes and their positions. Assume we have a complete binary tree, that is, a binary tree in which every level, except possibly the last, is completely filled, and all nodes are as far left as possible. If we walk through the tree in breadth-first order and number the nodes from $0$ to $n - 1$, where $n$ is the number of nodes, each number identifies a node and its position in the tree structure.
A complete binary tree is efficiently implemented as an array, where the node at index $i$ has children at indices $2i + 1$ and $2i + 2$ and a parent at index $\left\lfloor (i - 1) / 2 \right\rfloor$. Instead of a low-level array, we considered two of Python's built-in structures: `list` and `dict`.
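The index arithmetic fits in three small helpers. This is a sketch; the function names are our own and not necessarily those used in the PyMC branch. Python's floor division `//` implements the floor in the parent formula for non-negative indices.

```python
def left_child(i: int) -> int:
    """Index of the left child of node i."""
    return 2 * i + 1


def right_child(i: int) -> int:
    """Index of the right child of node i."""
    return 2 * i + 2


def parent(i: int) -> int:
    """Index of the parent of node i; the root (i = 0) has none."""
    if i <= 0:
        raise ValueError("the root node has no parent")
    return (i - 1) // 2


# Node 2 has children 5 and 6, and both map back to 2.
print(left_child(2), right_child(2), parent(5), parent(6))  # 5 6 2 2
```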
Note that although the indices come from numbering a complete binary tree, BART does not necessarily build complete binary trees. The only thing we can guarantee about the tree structure is that each node has either zero or two children. Still, this numbering will prove useful for indexing our structure.
If we implemented this structure with a list, we would waste a lot of space, since we would have to create dummy entries for non-existent nodes. Thus, we coded the tree structure as a dictionary, where the keys are the nodes' positions and the values are the nodes themselves.
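To make the space argument concrete, consider a hypothetical tree that keeps splitting along its right-most branch. As a dictionary it needs one entry per node; as a list it needs a slot for every index up to the largest one. The string values are placeholders standing in for node objects.

```python
# Keys are breadth-first indices; node 2 splits into 5 and 6, node 6 into 13 and 14.
tree = {
    0: "split",
    1: "leaf",
    2: "split",
    5: "leaf",
    6: "split",
    13: "leaf",
    14: "leaf",
}

# A list representation needs dummy entries (None) for the missing indices.
as_list = [tree.get(i) for i in range(max(tree) + 1)]
print(len(tree), len(as_list), as_list.count(None))  # 7 15 8

# With a dict, random access and insertion are plain key operations.
node = tree[6]
tree[27], tree[28] = "leaf", "leaf"  # grow node 13 (children 2*13+1 and 2*13+2)
tree[13] = "split"
# A list would now need 29 slots for 9 nodes.
```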
Tree nodes
Both splitting and leaf nodes inherit from a base class called `BaseNode`, which has two attributes: `index` and `depth`.
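A minimal sketch of such a base class follows. One assumption of ours: the depth is derived from the index rather than passed in, which works because the breadth-first numbering determines the level of each node. The class in the PyMC branch may store it differently.

```python
class BaseNode:
    """Hypothetical common base for splitting and leaf nodes."""

    def __init__(self, index: int) -> None:
        if index < 0:
            raise ValueError("index must be non-negative")
        self.index = index
        # Level d holds indices 2**d - 1 through 2**(d + 1) - 2, so index + 1
        # has exactly d + 1 bits in binary.
        self.depth = (index + 1).bit_length() - 1


print([BaseNode(i).depth for i in range(8)])  # [0, 1, 1, 2, 2, 2, 2, 3]
```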
The splitting nodes store the splitting variable and the split value. Since BART allows both quantitative and qualitative splitting variables, the splitting nodes must make that distinction possible.
The leaf nodes only hold the response of the tree for a particular region of the predictor space.
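One way to support both kinds of split is sketched below as a standalone class (in the design described here it would inherit from `BaseNode`). The field names and the rule for qualitative splits are assumptions: a quantitative split sends $x$ to the left child when the value of the splitting variable is at most the split value, and a qualitative split sends it left when its category belongs to a chosen set of categories.

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class SplitNode:
    """Hypothetical splitting node (field names are illustrative)."""
    index: int
    split_variable: int   # position of the predictor in x
    split_value: Any      # threshold (quantitative) or set of categories (qualitative)
    quantitative: bool = True

    def goes_left(self, x: Sequence[Any]) -> bool:
        """Evaluate the splitting rule for a single observation x."""
        value = x[self.split_variable]
        if self.quantitative:
            return value <= self.split_value
        return value in self.split_value

    def next_index(self, x: Sequence[Any]) -> int:
        """Breadth-first index of the child that x is routed to."""
        return 2 * self.index + 1 if self.goes_left(x) else 2 * self.index + 2


x = [42.0, "red"]
age_rule = SplitNode(index=0, split_variable=0, split_value=30.0)
color_rule = SplitNode(index=2, split_variable=1, split_value={"red", "blue"}, quantitative=False)
print(age_rule.next_index(x), color_rule.next_index(x))  # 2 5
```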
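Putting the two node types together, the prediction of a single tree for an observation $x$ is obtained by following splitting rules from the root until a leaf is reached and returning that leaf's value. This illustrative sketch handles only quantitative splits, and its classes are simplified standalone versions, not the PyMC ones.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Union


@dataclass
class LeafNode:
    """Hypothetical leaf: holds the response for its region of the predictor space."""
    index: int
    value: float


@dataclass
class SplitNode:
    """Simplified quantitative split: go left when x[split_variable] <= split_value."""
    index: int
    split_variable: int
    split_value: float


def predict(tree: Dict[int, Union[SplitNode, LeafNode]], x: Sequence[float]) -> float:
    """Route x from the root (index 0) down to a leaf and return its value."""
    node = tree[0]
    while isinstance(node, SplitNode):
        go_left = x[node.split_variable] <= node.split_value
        node = tree[2 * node.index + 1 if go_left else 2 * node.index + 2]
    return node.value


tree = {
    0: SplitNode(0, split_variable=0, split_value=0.5),
    1: LeafNode(1, value=-1.0),
    2: SplitNode(2, split_variable=1, split_value=2.0),
    5: LeafNode(5, value=0.3),
    6: LeafNode(6, value=1.7),
}
print(predict(tree, [0.2, 5.0]), predict(tree, [0.9, 1.0]), predict(tree, [0.9, 3.0]))  # -1.0 0.3 1.7
```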
Functions to grow and prune the tree
A tree can only grow from a leaf node. When this happens, the leaf node is replaced by a splitting node with two new leaf nodes as children. Conversely, to prune a tree, we select a prunable node (a splitting node whose two children are both leaf nodes), delete its children and replace the node with a leaf node.
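On the dictionary representation both moves reduce to a few key assignments and deletions. The sketch below uses hypothetical `grow`, `is_prunable` and `prune` functions; the new leaf values are passed in as arguments, whereas in BART they would be produced by the sampler.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class LeafNode:
    index: int
    value: float


@dataclass
class SplitNode:
    index: int
    split_variable: int
    split_value: float


Tree = Dict[int, Union[SplitNode, LeafNode]]


def grow(tree: Tree, index: int, split_variable: int, split_value: float,
         left_value: float, right_value: float) -> None:
    """Replace leaf `index` by a splitting node with two new leaf children."""
    if not isinstance(tree.get(index), LeafNode):
        raise ValueError(f"node {index} is not a leaf")
    tree[index] = SplitNode(index, split_variable, split_value)
    tree[2 * index + 1] = LeafNode(2 * index + 1, left_value)
    tree[2 * index + 2] = LeafNode(2 * index + 2, right_value)


def is_prunable(tree: Tree, index: int) -> bool:
    """True for a splitting node whose two children are both leaves."""
    return (isinstance(tree.get(index), SplitNode)
            and isinstance(tree.get(2 * index + 1), LeafNode)
            and isinstance(tree.get(2 * index + 2), LeafNode))


def prune(tree: Tree, index: int, value: float) -> None:
    """Delete the two leaf children of a prunable node and turn it into a leaf."""
    if not is_prunable(tree, index):
        raise ValueError(f"node {index} is not prunable")
    del tree[2 * index + 1]
    del tree[2 * index + 2]
    tree[index] = LeafNode(index, value)


tree: Tree = {0: LeafNode(0, 0.0)}
grow(tree, 0, split_variable=0, split_value=0.5, left_value=-1.0, right_value=1.0)
grow(tree, 2, split_variable=1, split_value=2.0, left_value=0.3, right_value=1.7)
print(sorted(tree), is_prunable(tree, 0), is_prunable(tree, 2))  # [0, 1, 2, 5, 6] False True
prune(tree, 2, value=1.0)
print(sorted(tree))  # [0, 1, 2]
```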
Checks of correctness of the tree
Although users will not create trees themselves, we want our code to fail as soon as something goes wrong (especially during development), so we added correctness checks that raise exceptions when the tree is invalid. We also wrote tests to verify that the implementation remains correct after each commit.
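A sketch of such a check for the dictionary representation follows. `check_tree` and `TreeStructureError` are hypothetical names. Most of the rules come from this post: the root exists, every splitting node has two children, leaves have none, and every non-root node hangs from a splitting node. The check that each node's stored `index` matches its key is an extra assumption of ours.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class LeafNode:
    index: int
    value: float


@dataclass
class SplitNode:
    index: int
    split_variable: int
    split_value: float


class TreeStructureError(Exception):
    """Raised when a dictionary does not describe a valid tree."""


def check_tree(tree: Dict[int, Union[SplitNode, LeafNode]]) -> None:
    """Raise TreeStructureError on the first structural problem found."""
    if 0 not in tree:
        raise TreeStructureError("the tree has no root node (index 0)")
    for key, node in tree.items():
        if not isinstance(node, (SplitNode, LeafNode)):
            raise TreeStructureError(f"node {key} is neither a split nor a leaf")
        if node.index != key:
            raise TreeStructureError(f"node stored at {key} reports index {node.index}")
        has_children = [child in tree for child in (2 * key + 1, 2 * key + 2)]
        if isinstance(node, SplitNode) and not all(has_children):
            raise TreeStructureError(f"splitting node {key} must have two children")
        if isinstance(node, LeafNode) and any(has_children):
            raise TreeStructureError(f"leaf node {key} must not have children")
        if key > 0 and not isinstance(tree.get((key - 1) // 2), SplitNode):
            raise TreeStructureError(f"node {key} is not attached to a splitting node")


# A valid tree passes silently.
check_tree({0: SplitNode(0, 0, 0.5), 1: LeafNode(1, -1.0), 2: LeafNode(2, 1.0)})
```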
All the code for the implementation of BART can be viewed in this branch of PyMC.