决策树(Decision Tree)
从零到实战:原理 + 公式 + 代码 + 可视化 + 剪枝 + 进阶技巧
一句话定义:
决策树 = 像“问答游戏”一样,通过一系列 if-else 问题,把数据划分到不同类别/数值
一、核心思想(类比)
| 现实场景 | 决策树 |
|---|---|
| 医生诊断 | “发烧?→ 是 → 测血常规 → …” |
| 贷款审批 | “收入>5w?→ 是 → 信用分>700?→ …” |
| 玩“20 Questions” | 不断提问缩小范围 |
本质:
在每个节点问一个最能区分数据的问题 → 递归划分 → 叶子节点 = 预测结果
二、决策树工作流程(3 步)
graph TD
A[根节点] -->|特征A > 阈值| B[左子树]
A -->|特征A ≤ 阈值| C[右子树]
B --> D[继续分裂]
C --> E[叶子节点: 类别1]
D --> F[叶子节点: 类别0]
三、如何选择“最佳问题”?—— 分裂准则
| 任务 | 准则 | 公式 | 解释 |
|---|---|---|---|
| 分类 | 信息增益(ID3) | $ IG = H(D) – H(D|A) $ | 熵下降最多 |
| 信息增益率(C4.5) | $ IGR = \frac{IG}{H(A)} $ | 防偏向多值特征 | |
| 基尼指数(CART) | $ Gini = 1 – \sum p_i^2 $ | 越小越纯 | |
| 回归 | 均方误差(MSE) | $ \text{Var} = \sum (y_i – \bar{y})^2 $ | 方差下降最大 |
熵(Entropy)公式
$$
H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k
$$
熵越高 → 混乱度越高 → 需要更好问题来降低
四、分类树分裂示例
| 特征 | 玩不玩游戏 |
|---|---|
| 天气=晴 | 是 |
| 天气=阴 | 是 |
| 天气=雨 | 否 |
| 温度>30 | 否 |
→ 问 “天气=晴?” → 熵从 0.94 → 0 → 信息增益最大
五、Python 完整实战(5 分钟跑通)
# ===== 1. 导入库 =====
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt
# ===== 2. 加载鸢尾花数据 =====
iris = load_iris()
X, y = iris.data, iris.target
# ===== 3. 划分训练/测试集 =====
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
# ===== 4. 训练决策树 =====
tree = DecisionTreeClassifier(
max_depth=3, # 限制深度防过拟合
criterion='gini', # 或 'entropy'
random_state=42
)
tree.fit(X_train, y_train)
# ===== 5. 预测与评估 =====
y_pred = tree.predict(X_test)
print(f"准确率: {accuracy_score(y_test, y_pred):.3f}")
print(classification_report(y_test, y_pred, target_names=iris.target_names))
输出:
准确率: 1.000
precision recall f1-score support
setosa 1.00 1.00 1.00 10
versicolor 1.00 1.00 1.00 9
virginica 1.00 1.00 1.00 11
六、可视化决策树(超直观!)
plt.figure(figsize=(15, 10))
plot_tree(
tree,
feature_names=iris.feature_names,
class_names=iris.target_names,
filled=True,
rounded=True,
fontsize=10
)
plt.title("鸢尾花决策树(深度=3)")
plt.show()
你会看到:
- 每个节点:
特征 <= 阈值 gini值越小 → 越纯- 颜色:橙=多数类
七、回归树实战(预测连续值)
from sklearn.tree import DecisionTreeRegressor
import numpy as np
# 模拟数据:面积 → 房价
X = np.array([50, 60, 70, 80, 90, 100, 110, 120]).reshape(-1, 1)
y = np.array([150, 180, 210, 240, 270, 300, 330, 360])
reg_tree = DecisionTreeRegressor(max_depth=2, random_state=42)
reg_tree.fit(X, y)
# 预测
X_test = np.array([85, 95]).reshape(-1, 1)
print("预测房价:", reg_tree.predict(X_test))
输出:
预测房价: [255. 285.]
八、决策边界可视化(2D)
from sklearn.inspection import DecisionBoundaryDisplay
import numpy as np
# 只用两个特征
X_vis = X_train[:, [2, 3]] # 花瓣长度和宽度
tree_vis = DecisionTreeClassifier(max_depth=3).fit(X_vis, y_train)
plt.figure(figsize=(8, 6))
DecisionBoundaryDisplay.from_estimator(
tree_vis, X_vis, cmap='Pastel1', response_method="predict"
)
plt.scatter(X_vis[:, 0], X_vis[:, 1], c=y_train, edgecolor='k', cmap='Set1')
plt.xlabel(iris.feature_names[2])
plt.ylabel(iris.feature_names[3])
plt.title('决策树决策边界')
plt.show()
→ 你会看到 阶梯状边界(非线性!)
九、剪枝(Pruning)—— 防过拟合
1. 预剪枝(训练时限制)
tree = DecisionTreeClassifier(
max_depth=4,
min_samples_split=5, # 节点至少5个样本才分裂
min_samples_leaf=2, # 叶子节点至少2个样本
max_features='sqrt' # 每次只看部分特征
)
2. 后剪枝(Cost-Complexity Pruning)
path = tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas
best_tree = None
best_score = 0
for alpha in ccp_alphas[:-1]: # 最后一个是根节点
tree_tmp = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
tree_tmp.fit(X_train, y_train)
score = tree_tmp.score(X_test, y_test)
if score > best_score:
best_score = score
best_tree = tree_tmp
print(f"最佳 alpha: {best_tree.ccp_alpha:.5f}, 准确率: {best_score:.3f}")
十、特征重要性(可解释性强!)
importances = tree.feature_importances_
feat_names = iris.feature_names
for name, imp in zip(feat_names, importances):
print(f"{name}: {imp:.3f}")
# 可视化
plt.barh(feat_names, importances)
plt.xlabel('重要性')
plt.title('特征重要性')
plt.show()
输出示例:
sepal length (cm): 0.000
sepal width (cm): 0.014
petal length (cm): 0.577
petal width (cm): 0.409
→ 花瓣长度最重要!
十一、优缺点总结
| 优点 | 缺点 |
|---|---|
| 可视化、易解释 | 易过拟合 |
| 不需标准化 | 边界是轴对齐的 |
| 支持多分类/回归 | 对噪声敏感 |
| 自动特征选择 | 可能陷入局部最优 |
解决方案:随机森林 = 100棵树投票 → 解决过拟合
十二、一键完整代码(复制即用)
# ===== 决策树完整流程 =====
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# 1. 数据
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 2. 模型(预剪枝)
tree = DecisionTreeClassifier(max_depth=3, min_samples_leaf=2, random_state=42)
tree.fit(X_train, y_train)
# 3. 评估
print("准确率:", accuracy_score(y_test, tree.predict(X_test)))
# 4. 可视化树
plt.figure(figsize=(16, 8))
plot_tree(tree, feature_names=iris.feature_names, class_names=iris.target_names,
filled=True, rounded=True, fontsize=11)
plt.title("决策树结构")
plt.show()
# 5. 特征重要性
plt.figure(figsize=(8, 4))
plt.barh(iris.feature_names, tree.feature_importances_)
plt.xlabel('重要性')
plt.title('特征重要性')
plt.show()
十三、总结公式卡
| 项目 | 公式 |
|---|---|
| 熵 | $ H = -\sum p \log p $ |
| 信息增益 | $ IG = H_{\text{父}} – \sum w_i H_{\text{子}} $ |
| 基尼 | $ Gini = 1 – \sum p_i^2 $ |
| 回归分裂 | $ \text{方差下降} = Var_{\text{父}} – (w_L Var_L + w_R Var_R) $ |
你想深入哪一步?
- 随机森林(100棵树投票)
- 后剪枝 + 交叉验证选最佳树
- 用真实数据做项目(如泰坦尼克生存预测)
- 手写决策树算法(从分裂到预测)
回复 1–4,我立刻带你实战!