Easy way to understand Decision Tree Pruning

Sooner or later in your journey with Machine Learning algorithms, there comes a time when your algorithm does not work well. With Decision Trees, it is important to understand decision tree pruning for the case where your model performs well on the training data but poorly on a testing dataset, and changing or filtering your data does not help.

Every machine learning algorithm, including Decision Trees, has its limits and cannot be pushed beyond a certain point. Pruning comes into the picture when your decision tree is overfitting the given training data and needs to be trimmed here and there so that it generalizes well, as you will see in this article.


What are Decision Trees?

As the name suggests, a decision tree works just like a tree with branches. The foundation of the tree, or its base, is the root node. From there flows a series of decision nodes that represent choices or decisions to be made. The nodes at the ends of the branches, called leaf nodes, represent the results of those decisions. A decision node represents a split point, and the leaf nodes that stem from a decision node represent the possible answers.

Just as leaves grow on branches, the leaf nodes grow out of the decision nodes on the branches of a Decision Tree, and each subsection of the tree is called a “branch.” For example, a decision node might ask, “Are you a diabetic?” with the leaf nodes ‘yes’ and ‘no’. When a tree keeps adding such splits until it overfits the data, decision tree pruning can cut down the number of leaves to give a more general decision tree.

This is how overfitting happens: for consistent, error-free data, you can always construct a decision tree that correctly labels every element of the training set, but that tree may be exponential in size.
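To make the overfitting point concrete, here is a minimal sketch, assuming scikit-learn and NumPy are installed, that fits an unrestricted DecisionTreeClassifier to synthetic data whose labels are pure random noise (the data set is made up for illustration). The tree still reaches perfect training accuracy by growing a large number of leaves, while its test accuracy stays near chance.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Purely random data: the labels carry no signal at all, so any "pattern"
# the tree learns is memorised noise.
rng = np.random.default_rng(0)
X = rng.normal(size=(600, 5))
y = rng.integers(0, 2, size=600)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# No depth or leaf limits: the tree keeps splitting until every leaf is pure.
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

train_acc = full_tree.score(X_train, y_train)
test_acc = full_tree.score(X_test, y_test)
print(f"train accuracy: {train_acc:.2f}")  # perfect fit on the training set
print(f"test accuracy:  {test_acc:.2f}")   # roughly a coin flip
print(f"leaves: {full_tree.get_n_leaves()}")
```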

Find out more about Decision Trees here: Decision Trees

Some key terminology

  • A Root node is at the base of the decision tree.
  • The process of dividing a node into sub-nodes is called Splitting.
  • When a sub-node is further split into additional sub-nodes, it is called a Decision node.
  • When a sub-node depicts a possible outcome and cannot be split further, it is a Leaf node.
  • The process by which sub-nodes of a decision tree are removed is called Pruning.
  • A subsection of the decision tree consisting of multiple nodes is called a Branch.
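To map these terms onto a fitted model, the sketch below (assuming scikit-learn, with the Iris data set and max_depth=3 chosen only for illustration) reads the node arrays that scikit-learn stores in `clf.tree_`. Because every split in a scikit-learn tree is binary, the number of leaf nodes is always one more than the number of decision nodes.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

tree = clf.tree_
# In scikit-learn's array representation, node 0 is the root node and a
# node whose left and right child ids are both -1 is a leaf node.
is_leaf = (tree.children_left == -1) & (tree.children_right == -1)
n_leaves = int(is_leaf.sum())
n_decision = tree.node_count - n_leaves  # internal (decision) nodes, root included

print("total nodes:   ", tree.node_count)
print("decision nodes:", n_decision)
print("leaf nodes:    ", n_leaves)
print("depth:         ", clf.get_depth())
```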

What is Pruning?

According to Google, to prune is to “trim (a tree, shrub, or bush) by cutting away dead or overgrown branches or stems, especially to encourage growth.” Decision tree pruning works the same way: overgrown branches are cut back to produce a more effective decision tree model. It is a less well-known but effective way to ensure that your Decision Tree performs well on real-world data instead of overfitting the training dataset.

Gini impurity and Entropy are the split criteria a Decision Tree uses to measure how mixed the classes in a node are. At each node, the tree picks the split that reduces impurity the most, that is, the split with the highest information gain. These criteria decide where to split, but on their own they do not stop a tree from growing until it overfits, which is why pruning is still needed. To know more about Gini and Entropy, check out our previous article here: Decision Tree Gini and Entropy
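As a worked illustration of what these criteria measure, here is a small, dependency-free Python sketch of Gini impurity, entropy, and information gain for a single candidate split. The helper names and the toy yes/no labels are hypothetical. In scikit-learn itself you choose between the two criteria with `DecisionTreeClassifier(criterion="gini")` or `criterion="entropy"`.

```python
import math
from collections import Counter

def class_proportions(labels):
    """Fraction of samples belonging to each class in a node."""
    counts = Counter(labels)
    total = len(labels)
    return [c / total for c in counts.values()]

def gini(labels):
    """Gini impurity: 1 - sum(p_k^2). 0 means the node is pure."""
    return 1.0 - sum(p * p for p in class_proportions(labels))

def entropy(labels):
    """Entropy in bits: -sum(p_k * log2(p_k)). 0 means the node is pure."""
    return -sum(p * math.log2(p) for p in class_proportions(labels))

def information_gain(parent, left, right, impurity=entropy):
    """Impurity of the parent minus the size-weighted impurity of the children."""
    n = len(parent)
    weighted = len(left) / n * impurity(left) + len(right) / n * impurity(right)
    return impurity(parent) - weighted

# A perfectly mixed parent node split into two mostly-pure children.
parent = ["yes"] * 5 + ["no"] * 5
left, right = ["yes"] * 4 + ["no"], ["yes"] + ["no"] * 4

print(gini(parent), entropy(parent))                    # 0.5 1.0
print(round(information_gain(parent, left, right), 3))  # 0.278
```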

To see Decision Tree Pruning in action, watch this YouTube video, which explains how it works: Decision Tree Pruning

Difference between Pre-Pruning and Post-Pruning

As the names suggest, pre-pruning (also called early stopping) stops the tree from growing before it has perfectly classified the training set. Growth is halted preemptively, on the expectation that the tree has already reached a good fit and that further splits would only make it overfit. Post-pruning, in contrast, lets the tree grow fully and then removes the excess branches and leaves after training has finished.
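Here is a minimal sketch of both approaches with scikit-learn's DecisionTreeClassifier, assuming the same breast cancer data used later in this article: `max_depth` and `min_samples_leaf` stop growth early (pre-pruning), while `ccp_alpha` grows the full tree and then cuts it back (post-pruning). The specific values `max_depth=3`, `min_samples_leaf=5` and `ccp_alpha=0.01` are arbitrary choices for illustration, not tuned settings.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Fully grown tree: no stopping rule, no pruning.
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Pre-pruning (early stopping): growth is halted while the tree is built.
pre = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5, random_state=0)
pre.fit(X_train, y_train)

# Post-pruning: the tree is grown fully, then cost complexity pruning
# removes the weakest branches. The alpha value here is illustrative.
post = DecisionTreeClassifier(ccp_alpha=0.01, random_state=0).fit(X_train, y_train)

for name, model in [("full", full), ("pre-pruned", pre), ("post-pruned", post)]:
    print(f"{name:12s} nodes={model.tree_.node_count:3d} "
          f"depth={model.get_depth()} test acc={model.score(X_test, y_test):.3f}")
```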

How to Prune a Decision Tree in Python?

The DecisionTreeClassifier in Python's Scikit Learn library provides parameters such as min_samples_leaf and max_depth that help the algorithm avoid overfitting by stopping growth early. Cost complexity pruning is another way to control the size of a tree, this time after it has been grown. This pruning technique is parameterized by the cost complexity parameter, ccp_alpha: the greater the ccp_alpha, the more nodes are pruned. Studying how ccp_alpha regularizes the tree, and choosing ccp_alpha based on validation scores, tells you when and how much to prune so that your tree fits the data well.
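For reference, the quantity that ccp_alpha controls is the cost-complexity measure of a tree T, as described in the scikit-learn user guide on minimal cost-complexity pruning. Here R(T) is the total impurity of T's leaves (scikit-learn uses the sample-weighted impurity), |T̃| is the number of leaves, and T_t is the branch rooted at node t:

```latex
R_\alpha(T) = R(T) + \alpha\,|\tilde{T}|
\qquad\qquad
\alpha_{\mathrm{eff}}(t) = \frac{R(t) - R(T_t)}{|\tilde{T}_t| - 1}
```

The branch with the smallest effective alpha is the weakest link and is pruned first, and pruning continues until the smallest remaining effective alpha exceeds ccp_alpha. That is why a larger ccp_alpha removes more nodes.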

Here is the code to compute and plot the pruning path:

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

# Load the breast cancer data and hold out a test set.
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Compute the pruning path: the effective alphas at which nodes get pruned,
# and the total leaf impurity of the tree at each of those alphas.
clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

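# Plot total leaf impurity against alpha. The last alpha is dropped because
# it corresponds to the trivial tree with only the root node.
fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], impurities[:-1], marker=".", color="r", drawstyle="steps-post")
ax.set_xlabel("effective alpha value")
ax.set_ylabel("total impurity of leaves in Decision Tree")
ax.set_title("Total Impurity vs effective alpha for a given training set")
plt.show()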

The code above, from the Scikit Learn documentation, plots the total impurity of the leaves against the effective alpha for a given training set. The result looks as follows:

[Figure: Total Impurity vs effective alpha for a given training set]

As alpha increases, more of the tree is pruned and the total impurity of its leaves rises. These effective alpha values are the candidates from which you pick the right alpha for your tree.

As a next step, use the following code to train one decision tree for each effective alpha value and store the fitted classifiers in a list:

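# Train one tree per effective alpha; larger alphas prune more aggressively.
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

# The largest alpha prunes the tree down to the root node alone.
print(
    "Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
        clfs[-1].tree_.node_count, ccp_alphas[-1]
    )
)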

To plot accuracy against alpha for the training and testing data, and find the best value of alpha for your data, use the following code:

train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

fig, ax = plt.subplots()
ax.set_xlabel("alpha")
ax.set_ylabel("accuracy")
ax.set_title("Accuracy vs alpha for training and testing sets")
ax.plot(ccp_alphas, train_scores, marker="o", label="train", drawstyle="steps-post")
ax.plot(ccp_alphas, test_scores, marker="o", label="test", drawstyle="steps-post")
ax.legend()
plt.show()

The resulting graph helps you choose the alpha value that gives the best accuracy for your model.

[Figure: Accuracy vs alpha for training and testing sets]

As the graph shows, test accuracy is highest at around alpha = 0.015, where pruning clearly helps. (This value is for this example only; the best alpha depends on your data.)
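Picking alpha by its test accuracy, as in the graph, uses the test set for model selection, so the reported test score becomes an optimistic estimate. A common alternative, sketched below under the assumption that 5-fold cross-validation on the training data suits your problem, is to tune ccp_alpha with scikit-learn's GridSearchCV and keep the test set for one final evaluation.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Candidate alphas come from the pruning path of the training data only.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
# Drop the last alpha, which prunes the tree down to the root node alone.
candidate_alphas = path.ccp_alphas[:-1]

# 5-fold cross-validation on the training set picks alpha; the test set
# is used only once, to report the final score.
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"ccp_alpha": candidate_alphas},
    cv=5,
)
search.fit(X_train, y_train)

best_alpha = search.best_params_["ccp_alpha"]
test_accuracy = search.score(X_test, y_test)
print(f"best ccp_alpha (CV): {best_alpha:.4f}")
print(f"test accuracy:       {test_accuracy:.3f}")
```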

To see how the number of nodes and the tree depth decrease as alpha increases, and for the full code demonstration, see the scikit-learn example: Post pruning decision trees with cost complexity pruning.
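A condensed sketch of that part of the scikit-learn example, assuming the same breast cancer data and train/test split as above: it refits one tree per effective alpha, then plots the node count and depth of each tree against alpha.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# Refit one tree per alpha and record its size.
clfs = [DecisionTreeClassifier(random_state=0, ccp_alpha=a).fit(X_train, y_train)
        for a in ccp_alphas]
node_counts = [clf.tree_.node_count for clf in clfs]
depths = [clf.get_depth() for clf in clfs]

# Two stacked panels: node count (top) and depth (bottom) against alpha.
fig, ax = plt.subplots(2, 1)
ax[0].plot(ccp_alphas, node_counts, marker="o", drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[1].plot(ccp_alphas, depths, marker="o", drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
fig.tight_layout()
plt.show()
```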

Conclusion

It is not always clear whether your decision tree needs pruning. One sign is a training error, such as the sum of squared errors (SSE), that is very low while the error on test data is much higher: the model is overfitting the training data. This calls for pruning, which reduces the overfit at the cost of a slightly higher training SSE with every leaf pruned, a trade that pays off on unseen data. Decision Tree Pruning is an essential way to make a Decision Tree perform well on real-world data.
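A minimal sketch of that trade-off, assuming a synthetic one-dimensional regression problem (a noisy sine curve made up for illustration) and scikit-learn's DecisionTreeRegressor: the fully grown tree drives the training SSE to, or very near, zero by memorising the noise, while the pruned tree (ccp_alpha=0.01, an arbitrary value) has far fewer leaves and a somewhat higher training SSE.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic 1-D regression data: a sine curve plus Gaussian noise.
rng = np.random.default_rng(0)
X = np.sort(rng.uniform(0, 5, size=(200, 1)), axis=0)
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=200)

def train_sse(model):
    """Sum of squared errors of the model on its own training data."""
    residuals = y - model.predict(X)
    return float(np.sum(residuals ** 2))

full = DecisionTreeRegressor(random_state=0).fit(X, y)
pruned = DecisionTreeRegressor(random_state=0, ccp_alpha=0.01).fit(X, y)

print(f"full tree:   {full.get_n_leaves():3d} leaves, training SSE = {train_sse(full):.3f}")
print(f"pruned tree: {pruned.get_n_leaves():3d} leaves, training SSE = {train_sse(pruned):.3f}")
```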

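If a single pruned decision tree still does not meet your needs, try a Random Forest: an ensemble of decision trees, each trained on a bootstrap sample of the data with a random subset of features considered at each split, whose predictions are combined. Combining many trees reduces the variance that makes a single tree prone to overfitting, and Random Forests are among the most widely used supervised machine learning algorithms.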
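If you want to try that, here is a minimal sketch comparing a single unpruned tree with scikit-learn's RandomForestClassifier on the same breast cancer split used above. `n_estimators=200` is an illustrative choice, and the accuracies you see will depend on your data and split.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# 200 trees, each fit on a bootstrap sample with a random subset of features
# considered at every split; scikit-learn combines them by averaging the
# trees' predicted class probabilities.
forest = RandomForestClassifier(n_estimators=200, random_state=0).fit(X_train, y_train)

tree_acc = tree.score(X_test, y_test)
forest_acc = forest.score(X_test, y_test)
print(f"single tree test accuracy:   {tree_acc:.3f}")
print(f"random forest test accuracy: {forest_acc:.3f}")
```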

For more such content, check out our website: Buggy Programmer
