Decision Trees in scikit-learn: A Practical Intro
A working introduction to decision trees in scikit-learn covering splitting criteria, overfitting, max_depth tuning, visualization, and the path to random forests.
What you'll learn
- How a decision tree splits data using Gini or entropy
- Why unconstrained trees almost always overfit
- How max_depth and min_samples_leaf control complexity
- How to visualize a trained tree for stakeholder review
- How a single tree leads naturally to random forests
Prerequisites
- A solid grasp of [what machine learning is](/blog/what-is-machine-learning)
- Comfort with [pandas dataframes](/blog/pandas-dataframes-basics)
- Familiarity with [train/test split and metrics](/blog/ml-train-test-split-and-metrics)
Decision trees are the workhorse of tabular machine learning. They require no feature scaling, cope with features on very different ranges, and produce models that a non-technical reader can follow rule by rule. One caveat: scikit-learn's trees expect numeric input, so categorical columns must be encoded before training. Trees also overfit ferociously if you let them, so a working knowledge of their hyperparameters matters more than for almost any other model.
How a tree learns
A decision tree partitions the feature space using a sequence of yes-or-no questions. At each node the algorithm searches over every feature and every candidate split point, picks the one that makes the resulting child nodes most pure, and recurses. Leaves predict the majority class for classification, or the mean target for regression.
“Most pure” is measured by an impurity metric. For classification the two common choices are Gini impurity and entropy. Both reach zero when a node contains only one class and a maximum when classes are perfectly mixed. In practice the two produce nearly identical trees, so the default Gini is fine almost all the time.
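To make the two impurity metrics concrete, here is a minimal sketch that computes them from per-class row counts. The helper names gini and entropy and the example counts are illustrative assumptions, not part of scikit-learn's API.

```python
import numpy as np

def gini(counts):
    """Gini impurity of a node, given the number of rows in each class."""
    p = np.asarray(counts, dtype=float)
    p = p / p.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(counts):
    """Shannon entropy (in bits) of a node, given per-class row counts."""
    p = np.asarray(counts, dtype=float)
    p = p / p.sum()
    p = p[p > 0]  # empty classes contribute nothing to the sum
    return np.sum(p * np.log2(1.0 / p))

pure = [50, 0, 0]     # only one class present
mixed = [50, 50, 50]  # three classes, perfectly mixed

print("gini   :", gini(pure), round(gini(mixed), 3))
print("entropy:", entropy(pure), round(entropy(mixed), 3))
```

Both functions return zero for the pure node. For three perfectly mixed classes the maxima are 2/3 for Gini and log2(3), roughly 1.585 bits, for entropy.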
For regression the default criterion is mean squared error (criterion="squared_error" in recent scikit-learn versions): choose the split that minimises the average squared distance between each row’s target and the mean of the child node it falls into.
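The split search for regression can be sketched on a single feature. This is a toy illustration, not scikit-learn's actual implementation: the function name best_split, the made-up data, and the use of midpoints between adjacent sorted values as candidate thresholds are all assumptions of the sketch.

```python
import numpy as np

def best_split(x, y):
    """Return (threshold, weighted_mse) for the best single split on x."""
    order = np.argsort(x)
    x, y = x[order], y[order]
    best_thr, best_score = None, np.inf
    for i in range(1, len(x)):
        if x[i] == x[i - 1]:
            continue  # no threshold can separate identical values
        left, right = y[:i], y[i:]
        # Weighted average of each child's MSE around its own mean
        score = (len(left) * left.var() + len(right) * right.var()) / len(y)
        if score < best_score:
            best_thr, best_score = (x[i - 1] + x[i]) / 2, score
    return best_thr, best_score

# Made-up data with an obvious jump between x=3 and x=10
x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([5.0, 6.0, 5.5, 20.0, 21.0, 19.5])
thr, score = best_split(x, y)
print("threshold:", thr, "weighted MSE:", round(score, 3))
```

A real tree repeats this search over every feature at every node and keeps the overall winner.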
A first model
We will train a classifier on the classic Iris dataset.
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
iris = load_iris(as_frame=True)
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
tree = DecisionTreeClassifier(criterion="gini", random_state=0)
tree.fit(X_train, y_train)
preds = tree.predict(X_test)
print("train acc:", round(tree.score(X_train, y_train), 3))
print("test acc:", round(accuracy_score(y_test, preds), 3))
Run this and you will almost certainly see a training accuracy of 1.0. That is your first lesson: an unconstrained tree grows until every leaf is pure, which means it has memorised the training set.
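One way to see how far the unconstrained tree grew is to ask it for its depth and leaf count. The sketch below retrains the same model so it runs on its own; get_depth and get_n_leaves are methods of the fitted estimator, and the exact numbers you see may vary with your scikit-learn version.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

print("depth :", tree.get_depth())
print("leaves:", tree.get_n_leaves())

# Leaves have no children; scikit-learn marks that with -1
is_leaf = tree.tree_.children_left == -1
print("largest leaf impurity:", tree.tree_.impurity[is_leaf].max())
```

If the largest leaf impurity prints as 0.0, every leaf holds rows from a single class, which is the memorisation described above.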
Overfitting and the depth knob
The most important hyperparameter is max_depth. With max_depth=None, the tree can grow as deep as it likes. Set it to a small integer and the tree is forced to commit to a coarser decision rule.
from sklearn.tree import DecisionTreeClassifier
for d in [1, 2, 3, 4, 6, None]:
    # Refit at each depth and compare train vs. test accuracy
    t = DecisionTreeClassifier(max_depth=d, random_state=0).fit(X_train, y_train)
    print(f"depth={d!s:>4} train={t.score(X_train, y_train):.2f} "
          f"test={t.score(X_test, y_test):.2f}")
You will typically see training accuracy climb monotonically while test accuracy peaks at some middling depth and then plateaus or drops. That is the visible signature of overfitting.
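Picking the depth by looking at test accuracy quietly tunes the model to the test set. A common alternative, sketched here as one option rather than the only way to do it, is to compare depths with cross-validation on the training rows and keep the test split for a final check. The choice of five folds is an assumption.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

cv_means = {}
for d in [1, 2, 3, 4, 6, None]:
    model = DecisionTreeClassifier(max_depth=d, random_state=0)
    # 5-fold cross-validation uses only the training rows
    scores = cross_val_score(model, X_train, y_train, cv=5)
    cv_means[d] = scores.mean()
    print(f"depth={d!s:>4} cv acc={scores.mean():.3f} +/- {scores.std():.3f}")

best_depth = max(cv_means, key=cv_means.get)
print("best depth by CV:", best_depth)
```

On a dataset as small as Iris the differences between depths are often within the fold-to-fold noise, so treat the winner as a hint rather than a verdict.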
Two other knobs are worth knowing. min_samples_leaf sets the minimum number of training rows allowed in any leaf, which limits growth according to how much data supports each rule rather than by a fixed depth. min_samples_split sets the minimum number of rows a node must hold before the algorithm will consider splitting it. Setting min_samples_leaf=5 is a reasonable starting point for noisy real-world data.
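The effect of min_samples_leaf is easy to check by counting leaves and inspecting the smallest one. This is a small sketch; the values 1, 5, and 20 are arbitrary choices for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

results = []
for leaf in [1, 5, 20]:
    t = DecisionTreeClassifier(min_samples_leaf=leaf, random_state=0)
    t.fit(X_train, y_train)
    is_leaf = t.tree_.children_left == -1  # -1 marks a leaf node
    smallest = int(t.tree_.n_node_samples[is_leaf].min())
    results.append((leaf, t.get_n_leaves(), smallest))
    print(f"min_samples_leaf={leaf:>2} leaves={t.get_n_leaves():>2} "
          f"smallest leaf={smallest:>2} test={t.score(X_test, y_test):.2f}")
```

The smallest leaf never holds fewer rows than the setting allows, so raising the value prunes away rules that only a handful of rows support.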
Visualising the tree
Decision trees are one of the few models you can read directly. scikit-learn ships two helpers.
from sklearn.tree import plot_tree, export_text
import matplotlib.pyplot as plt
small_tree = DecisionTreeClassifier(max_depth=3, random_state=0)
small_tree.fit(X_train, y_train)
print(export_text(small_tree, feature_names=list(X.columns)))
plt.figure(figsize=(12, 6))
plot_tree(
    small_tree,
    feature_names=list(X.columns),
    class_names=list(iris.target_names),  # plain list for consistency with feature_names
    filled=True,
)
plt.show()
export_text is fantastic for code reviews because it produces an ASCII tree you can paste into a pull request description. plot_tree is what you want for slides or a notebook handed to a product manager.
Reading feature importances
A trained tree exposes a feature_importances_ array that ranks features by how much they reduced impurity across all splits.
import pandas as pd
imp = pd.Series(small_tree.feature_importances_, index=X.columns)
print(imp.sort_values(ascending=False))
Treat these as a directional signal, not a causal claim. Two correlated features can end up with the importance split unevenly because whichever feature happens to be chosen first absorbs most of the credit.
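A quick experiment shows why correlated features muddy the picture. The sketch below adds an exact copy of one column (the name "petal width copy" is made up for this example) and compares importances before and after. How the credit lands depends on tie-breaking and random_state, so your numbers may differ.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True, as_frame=True)

X_dup = X.copy()
# Hypothetical column: a perfectly correlated duplicate of an existing feature
X_dup["petal width copy"] = X_dup["petal width (cm)"]

base = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
dup = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_dup, y)

print("original features:")
print(pd.Series(base.feature_importances_, index=X.columns).round(3))
print("with duplicated column:")
print(pd.Series(dup.feature_importances_, index=X_dup.columns).round(3))
```

The two columns carry identical information, yet the tree has no reason to share credit evenly between them. Importances always sum to 1, so whatever one twin gains the other loses.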
From a single tree to a forest
A single tree is high variance: small changes in training data can shuffle the splits at the top of the tree and produce a very different model. That instability is exactly what ensembles are designed to exploit. A random forest trains many trees, each on a bootstrapped sample of the rows and with only a random subset of the features considered at each split, then averages their predictions.
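The idea can be sketched by hand with nothing but single trees. This is a simplified illustration, not RandomForestClassifier's actual implementation: the count of 50 trees is arbitrary, and max_features="sqrt" is used here to get the per-split feature subsampling.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

rng = np.random.default_rng(0)
trees = []
for i in range(50):
    # Bootstrap: draw as many rows as the training set has, with replacement
    idx = rng.integers(0, len(X_train), size=len(X_train))
    # max_features="sqrt" makes each split consider a random feature subset
    t = DecisionTreeClassifier(max_features="sqrt", random_state=i)
    trees.append(t.fit(X_train[idx], y_train[idx]))

# Average per-tree class probabilities, then pick the most likely class.
# This assumes every bootstrap sample contains all three classes.
proba = np.mean([t.predict_proba(X_test) for t in trees], axis=0)
manual_preds = proba.argmax(axis=1)
manual_acc = (manual_preds == y_test).mean()
print("hand-rolled forest test acc:", round(manual_acc, 3))
```

Each tree sees a slightly different dataset and makes different mistakes, and averaging cancels much of that noise.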
from sklearn.ensemble import RandomForestClassifier
forest = RandomForestClassifier(
    n_estimators=300,
    max_depth=None,
    min_samples_leaf=2,
    random_state=0,
    n_jobs=-1,
)
forest.fit(X_train, y_train)
print("forest test acc:", round(forest.score(X_test, y_test), 3))
You will usually see a small but real improvement over the best single tree, and the variance from run to run drops dramatically. If you want to push further still, gradient boosted trees from libraries like XGBoost or LightGBM are the standard next step.
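To check the stability claim on your own data, one rough approach is to repeat the train/test split with different seeds and compare the spread of test scores. The sketch below assumes ten repeats and 100 trees to keep it quick. On a dataset as easy as Iris the gap may be small.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

tree_scores, forest_scores = [], []
for seed in range(10):
    # A fresh split and a fresh model seed on every repeat
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=seed
    )
    single = DecisionTreeClassifier(random_state=seed).fit(X_tr, y_tr)
    forest = RandomForestClassifier(n_estimators=100, random_state=seed).fit(X_tr, y_tr)
    tree_scores.append(single.score(X_te, y_te))
    forest_scores.append(forest.score(X_te, y_te))

print(f"tree   mean={np.mean(tree_scores):.3f} std={np.std(tree_scores):.3f}")
print(f"forest mean={np.mean(forest_scores):.3f} std={np.std(forest_scores):.3f}")
```

A lower standard deviation for the forest is the run-to-run stability described above. A noisier dataset will usually show the difference more clearly.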
When to use a single tree
Even with forests available, a shallow single tree still earns its keep. It is the right model when interpretability is non-negotiable, when latency must be a handful of microseconds, or when the tree itself is a deliverable, for example as a credit underwriting rulebook.
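When the tree itself is the deliverable, you can walk the fitted structure and emit one plain-language rule per leaf. The helper to_rules below is a hypothetical sketch built on the estimator's tree_ attribute, not a scikit-learn function, and the output format is an assumption.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris(as_frame=True)
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

def to_rules(clf, feature_names, class_names):
    """Return one 'IF ... THEN class' string per leaf (hypothetical helper)."""
    t = clf.tree_
    rules = []

    def walk(node, conditions):
        if t.children_left[node] == -1:  # -1 marks a leaf
            label = class_names[t.value[node][0].argmax()]
            rules.append("IF " + " AND ".join(conditions or ["TRUE"]) + f" THEN {label}")
            return
        name, thr = feature_names[t.feature[node]], t.threshold[node]
        # Rows with feature <= threshold go left, the rest go right
        walk(t.children_left[node], conditions + [f"{name} <= {thr:.2f}"])
        walk(t.children_right[node], conditions + [f"{name} > {thr:.2f}"])

    walk(0, [])
    return rules

rules = to_rules(clf, list(X.columns), list(iris.target_names))
print("\n".join(rules))
```

A rule list like this can be reviewed line by line by someone who has never opened a notebook, which is the whole point of keeping the tree shallow.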
Wrap up
Decision trees are the friendliest model to learn and the easiest to misuse. Always constrain depth, always evaluate on a held-out split, and always look at the tree itself. When the test accuracy stops improving, graduate to a random forest. Both models reward the same disciplined evaluation you would give any other learner; revisit the metrics article whenever you size up a new dataset.