The decision tree is one of the simplest yet most powerful supervised machine learning algorithms. It is used for both classification and regression problems and is also known as the Classification and Regression Tree (CART) algorithm.
Decision tree classifiers are used successfully in many diverse areas. Their most important feature is the capability to capture descriptive decision-making knowledge from the supplied data.
Decision tree or CART algorithm:
CART algorithms recursively split the data into partitions. You can keep track of these partitions in a tree structure. During testing, a test point traverses the tree until it falls into a leaf. Each leaf is associated with one of the data partitions, and you assign the test point the most common label within that partition (or average label in case of regression).
Why Decision tree?
1. It can be easily visualized so that humans can understand what’s going on inside: it works like a flowchart, where each level is a question with a yes or no answer.
2. It is easy to test: once the tree is built, any new test point just needs to be traversed through it to produce a prediction.
The figure below shows a simple example of a decision tree. Consider the scenario where we need to decide whether or not to go to the market to buy shampoo. Quite a hard decision, isn’t it?
Look at the visualized decision tree to see how simple questions can be used to split data.
Deep diving into Decision tree:
Unfortunately, in the real world we don’t get such easy tasks to run our CART algorithm on ☹. Instead, most of the time we convert textual data into 0s and 1s in order to classify our dataset, which makes our algorithm a bit harder to understand. In this blog I will show you how to build a decision tree from scratch in Python, and then we will walk through the Titanic dataset from Kaggle and run our algorithm on it.
Building blocks:
There are two main building blocks of decision tree:
a) Partition impurity: this decides how to partition the data, and on which feature.
b) Stopping criterion: since this is a recursive algorithm, we need to decide at what point it should stop.
Partition’s impurity: When building a tree, you want each partition to become “purer” (i.e. containing only data from a single class). If your partition is pure, you can easily and confidently assign labels to new data points that fall within it. We can use an impurity metric to measure a partition’s purity and compare it to other partitions.
Types of impurity functions
Gini impurity: The Gini impurity of a leaf with corresponding set S is the probability that two points, both picked uniformly at random from S, have different labels. Intuitively, we sum over all possible labels and compute the probability that the first point has a particular label k (probability pk) and that the second point does not have that label (probability 1 - pk).
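As a concrete illustration, that definition can be sketched in a few lines of Python (a hypothetical helper with a name of my own; the code we build later uses squared loss instead):

```python
import numpy as np

def gini_impurity(labels):
    """Gini impurity: the probability that two points drawn uniformly at
    random (with replacement) from the set carry different labels."""
    _, counts = np.unique(np.asarray(labels), return_counts=True)
    p = counts / counts.sum()      # p_k: fraction of points with label k
    return 1.0 - np.sum(p ** 2)    # equals sum_k p_k * (1 - p_k)
```

A pure set gives 0, and a 50/50 split of two classes gives the maximum value of 0.5.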
Squared-loss impurity:
We can tweak classification trees for continuous-valued labels (regression). Because the labels are no longer categorical, we redefine impurity so that it captures the spread of values in each node. The prediction made by a regression tree for a leaf with corresponding set S is simply the mean of the labels in S.
Finding the best split: Remember, you evaluate the quality of a split of a parent set Sp into two sets SL and SR by the weighted impurity of the two branches.
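A minimal sketch of that weighted-impurity computation, with the impurity function passed in as an argument (for the sum-form squared-loss impurity used later, the two branch sums can simply be added, since each sum already grows with branch size):

```python
import numpy as np

def weighted_impurity(impurity, yL, yR):
    """Quality of a split: each branch's impurity weighted by its
    share of the parent's points."""
    nL, nR = len(yL), len(yR)
    return (nL * impurity(yL) + nR * impurity(yR)) / (nL + nR)
```

For example, with per-point variance as the impurity, a split into a pure branch and a mixed branch scores between the two branch impurities, pulled toward the larger branch.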
The final and most important part is deciding when the recursive function should stop. Here we formalize this stopping criterion:
The CART algorithm will stop in exactly two cases:
1. If all data points in the set share the same label, we stop splitting and create a leaf with that label y.
2. If all the data points in the set share the same features, we create a leaf with the most common label y for classification, or the average label for regression.
Now that we have seen both building blocks of the decision tree algorithm, let’s start building it from scratch. For this blog I will use the squared-loss impurity function, as it is simple to understand; if you want to learn about Gini or any other impurity, you will find ample resources online.
Titanic Dataset
To evaluate our model, I am using the famous Titanic dataset from Kaggle. It consists of details of the passengers who were aboard the Titanic, each with the class label “survived” or “not survived”; our task is to predict this label for novel data points.
I have done some data cleaning that is beyond the scope of this blog; you can find the full code by clicking the link below.
Algorithm
Our algorithm will be divided into 4 parts or functions:
a) Squared-loss impurity:
This function calculates the impurity of the labels: it takes as input a vector of n labels and outputs the corresponding squared-loss impurity (the function is below).
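A minimal version of this function might look like the following (the name `sqimpurity` is my own choice):

```python
import numpy as np

def sqimpurity(yTr):
    """Squared-loss impurity: the sum of squared deviations of the labels
    from their mean. Zero when every label in the node agrees."""
    yTr = np.asarray(yTr, dtype=float)
    return np.sum((yTr - yTr.mean()) ** 2)
```

With labels encoded as +1/-1, a pure node scores 0 and a perfectly mixed pair scores 2.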
b) Split function:
This function takes the dataset with its labels as input and computes the best feature and cut value of an optimal split, based on the squared-loss impurity defined above. The cut value is the average of the two values, in the chosen dimension, between which the data points are split.
To make our lives easier, we first sort the dataset along each feature, loop through the sorted values, and wherever two adjacent values differ, take their average as a candidate cut and compute the resulting loss; we repeat this for every feature in the dataset. For example, say a feature takes the values [0, 1, 0, 2, 1] across five points; after sorting this becomes [0, 0, 1, 1, 2], and we loop through it looking for changes in value.
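Putting that together, the split search might be sketched as follows (names are my own; it returns the best feature index, cut value, and loss):

```python
import numpy as np

def sqsplit(xTr, yTr):
    """Scan every feature and every candidate cut, keeping the split with
    the smallest total squared-loss impurity over the two branches."""
    n, d = xTr.shape
    bestloss, feature, cut = np.inf, -1, np.inf
    for f in range(d):
        order = np.argsort(xTr[:, f])          # sort points along feature f
        xs, ys = xTr[order, f], yTr[order]
        for i in range(n - 1):
            if xs[i] == xs[i + 1]:
                continue                        # no valid cut between equal values
            yL, yR = ys[:i + 1], ys[i + 1:]
            # squared-loss impurities already sum over points, so add directly
            loss = np.sum((yL - yL.mean()) ** 2) + np.sum((yR - yR.mean()) ** 2)
            if loss < bestloss:
                bestloss = loss
                feature, cut = f, (xs[i] + xs[i + 1]) / 2.0
    return feature, cut, bestloss
```

On a toy set where the first two points are labelled -1 and the last two +1, the search places the cut midway between the second and third points with zero loss.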
We have also created two helper functions:
- The first checks whether all the values in the dataset (or in a column) are the same; if they are, we don’t split on it.
- The second returns the majority label in a list.
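These two helpers might be sketched as follows (names are my own):

```python
import numpy as np

def all_same(values):
    """True if every entry equals the first, i.e. there is nothing to split on."""
    values = np.asarray(values)
    return bool(np.all(values == values[0]))

def majority_label(labels):
    """The most common label in the list (ties broken by np.unique's sort order)."""
    vals, counts = np.unique(np.asarray(labels), return_counts=True)
    return vals[np.argmax(counts)]
```

For regression, `majority_label` would be replaced by the mean of the labels.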
We have also created a TreeNode class in order to build the tree structure.
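A minimal TreeNode might look like this (field names are my own):

```python
class TreeNode:
    """A node of the tree: internal nodes carry a split, leaves carry a prediction."""
    def __init__(self, left=None, right=None, feature=None, cut=None, prediction=None):
        self.left = left              # subtree for points with x[feature] <= cut
        self.right = right            # subtree for points with x[feature] > cut
        self.feature = feature        # index of the feature this node splits on
        self.cut = cut                # threshold used for the split
        self.prediction = prediction  # label stored at a leaf (None for internal nodes)
```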
c) CART function:
This function builds the tree. Before splitting, it checks the base cases and stops the recursion if either is met; otherwise it keeps splitting the dataset.
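A sketch of the recursive build, with compact copies of the node class and split search described in the text repeated so the snippet runs on its own:

```python
import numpy as np

class TreeNode:
    # compact node structure, as described in the text
    def __init__(self, left=None, right=None, feature=None, cut=None, prediction=None):
        self.left, self.right = left, right
        self.feature, self.cut, self.prediction = feature, cut, prediction

def sqsplit(xTr, yTr):
    # best (feature, cut) by total squared-loss impurity, as described in the text
    n, d = xTr.shape
    bestloss, feature, cut = np.inf, -1, np.inf
    for f in range(d):
        order = np.argsort(xTr[:, f])
        xs, ys = xTr[order, f], yTr[order]
        for i in range(n - 1):
            if xs[i] == xs[i + 1]:
                continue
            yL, yR = ys[:i + 1], ys[i + 1:]
            loss = np.sum((yL - yL.mean()) ** 2) + np.sum((yR - yR.mean()) ** 2)
            if loss < bestloss:
                bestloss, feature, cut = loss, f, (xs[i] + xs[i + 1]) / 2.0
    return feature, cut

def cart(xTr, yTr):
    """Recursively build the tree, stopping at the two base cases."""
    # Base case 1: all labels agree -> pure leaf with that label
    if np.all(yTr == yTr[0]):
        return TreeNode(prediction=yTr[0])
    # Base case 2: all rows identical, no split possible -> majority-label leaf
    if np.all(xTr == xTr[0]):
        vals, counts = np.unique(yTr, return_counts=True)
        return TreeNode(prediction=vals[np.argmax(counts)])
    # Otherwise: find the best split, partition the data, and recurse on each side
    feature, cut = sqsplit(xTr, yTr)
    mask = xTr[:, feature] <= cut
    return TreeNode(left=cart(xTr[mask], yTr[mask]),
                    right=cart(xTr[~mask], yTr[~mask]),
                    feature=feature, cut=cut)
```

Because a valid cut always leaves at least one point on each side, the recursion is guaranteed to terminate.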
d) Predict function:
The final function simply takes a test point, traverses the tree down to a leaf, and returns that leaf’s prediction.
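A sketch of the traversal (a minimal node structure is repeated so the snippet runs on its own; internal nodes are assumed to have prediction set to None):

```python
import numpy as np

class TreeNode:
    # compact node structure, as described in the text
    def __init__(self, left=None, right=None, feature=None, cut=None, prediction=None):
        self.left, self.right = left, right
        self.feature, self.cut, self.prediction = feature, cut, prediction

def predict(tree, x):
    """Walk from the root down to a leaf, following the split at each internal node."""
    node = tree
    while node.prediction is None:
        node = node.left if x[node.feature] <= node.cut else node.right
    return node.prediction
```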
Model evaluation on titanic dataset
As we can see, the model we built from scratch performs on par with Scikit-learn’s.
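For reference, the accuracy metric behind that comparison boils down to a few lines (the Scikit-learn baseline would be a `sklearn.tree.DecisionTreeClassifier` fitted on the same cleaned features):

```python
import numpy as np

def accuracy(preds, labels):
    """Fraction of test points whose predicted label matches the true one."""
    preds, labels = np.asarray(preds), np.asarray(labels)
    return float(np.mean(preds == labels))
```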
If you have any question or feedback, please let me know and thanks for reading!