决策树(Decision Tree)

决策树(Decision Tree)

从零到实战:原理 + 公式 + 代码 + 可视化 + 剪枝 + 进阶技巧

一句话定义
决策树 = 像“问答游戏”一样,通过一系列 if-else 问题,把数据划分到不同类别/数值


一、核心思想(类比)

现实场景决策树
医生诊断“发烧?→ 是 → 测血常规 → …”
贷款审批“收入>5w?→ 是 → 信用分>700?→ …”
玩“20 Questions”不断提问缩小范围

本质
在每个节点问一个最能区分数据的问题 → 递归划分 → 叶子节点 = 预测结果


二、决策树工作流程(3 步)

graph TD
    A[根节点] -->|特征A > 阈值| B[左子树]
    A -->|特征A ≤ 阈值| C[右子树]
    B --> D[继续分裂]
    C --> E[叶子节点: 类别1]
    D --> F[叶子节点: 类别0]

三、如何选择“最佳问题”?—— 分裂准则

任务准则公式解释
分类信息增益(ID3)$ IG = H(D) – H(D|A) $熵下降最多
信息增益率(C4.5)$ IGR = \frac{IG}{H(A)} $防偏向多值特征
基尼指数(CART)$ Gini = 1 – \sum p_i^2 $越小越纯
回归均方误差(MSE)$ \text{Var} = \sum (y_i – \bar{y})^2 $方差下降最大

熵(Entropy)公式

$$
H(D) = -\sum_{k=1}^{K} p_k \log_2 p_k
$$

熵越高 → 混乱度越高 → 需要更好问题来降低


四、分类树分裂示例

特征玩不玩游戏
天气=晴
天气=阴
天气=雨
温度>30

→ 问 “天气=晴?” → 熵从 0.94 → 0 → 信息增益最大


五、Python 完整实战(5 分钟跑通)

# ===== 1. 导入库 =====
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import matplotlib.pyplot as plt

# ===== 2. 加载鸢尾花数据 =====
iris = load_iris()
X, y = iris.data, iris.target

# ===== 3. 划分训练/测试集 =====
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# ===== 4. 训练决策树 =====
tree = DecisionTreeClassifier(
    max_depth=3,           # 限制深度防过拟合
    criterion='gini',      # 或 'entropy'
    random_state=42
)
tree.fit(X_train, y_train)

# ===== 5. 预测与评估 =====
y_pred = tree.predict(X_test)
print(f"准确率: {accuracy_score(y_test, y_pred):.3f}")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

输出

准确率: 1.000
              precision    recall  f1-score   support
      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00         9
   virginica       1.00      1.00      1.00        11

六、可视化决策树(超直观!)

plt.figure(figsize=(15, 10))
plot_tree(
    tree,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title("鸢尾花决策树(深度=3)")
plt.show()

你会看到:

  • 每个节点:特征 <= 阈值
  • gini 值越小 → 越纯
  • 颜色:橙=多数类

七、回归树实战(预测连续值)

from sklearn.tree import DecisionTreeRegressor
import numpy as np

# 模拟数据:面积 → 房价
X = np.array([50, 60, 70, 80, 90, 100, 110, 120]).reshape(-1, 1)
y = np.array([150, 180, 210, 240, 270, 300, 330, 360])

reg_tree = DecisionTreeRegressor(max_depth=2, random_state=42)
reg_tree.fit(X, y)

# 预测
X_test = np.array([85, 95]).reshape(-1, 1)
print("预测房价:", reg_tree.predict(X_test))

输出

预测房价: [255. 285.]

八、决策边界可视化(2D)

from sklearn.inspection import DecisionBoundaryDisplay
import numpy as np

# 只用两个特征
X_vis = X_train[:, [2, 3]]  # 花瓣长度和宽度
tree_vis = DecisionTreeClassifier(max_depth=3).fit(X_vis, y_train)

plt.figure(figsize=(8, 6))
DecisionBoundaryDisplay.from_estimator(
    tree_vis, X_vis, cmap='Pastel1', response_method="predict"
)
plt.scatter(X_vis[:, 0], X_vis[:, 1], c=y_train, edgecolor='k', cmap='Set1')
plt.xlabel(iris.feature_names[2])
plt.ylabel(iris.feature_names[3])
plt.title('决策树决策边界')
plt.show()

→ 你会看到 阶梯状边界(非线性!)


九、剪枝(Pruning)—— 防过拟合

1. 预剪枝(训练时限制)

tree = DecisionTreeClassifier(
    max_depth=4,
    min_samples_split=5,    # 节点至少5个样本才分裂
    min_samples_leaf=2,     # 叶子节点至少2个样本
    max_features='sqrt'     # 每次只看部分特征
)

2. 后剪枝(Cost-Complexity Pruning)

path = tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

best_tree = None
best_score = 0
for alpha in ccp_alphas[:-1]:  # 最后一个是根节点
    tree_tmp = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    tree_tmp.fit(X_train, y_train)
    score = tree_tmp.score(X_test, y_test)
    if score > best_score:
        best_score = score
        best_tree = tree_tmp

print(f"最佳 alpha: {best_tree.ccp_alpha:.5f}, 准确率: {best_score:.3f}")

十、特征重要性(可解释性强!)

importances = tree.feature_importances_
feat_names = iris.feature_names

for name, imp in zip(feat_names, importances):
    print(f"{name}: {imp:.3f}")

# 可视化
plt.barh(feat_names, importances)
plt.xlabel('重要性')
plt.title('特征重要性')
plt.show()

输出示例

sepal length (cm): 0.000
sepal width (cm): 0.014
petal length (cm): 0.577
petal width (cm): 0.409

花瓣长度最重要!


十一、优缺点总结

优点缺点
可视化、易解释易过拟合
不需标准化边界是轴对齐的
支持多分类/回归对噪声敏感
自动特征选择可能陷入局部最优

解决方案随机森林 = 100棵树投票 → 解决过拟合


十二、一键完整代码(复制即用)

# ===== 决策树完整流程 =====
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# 1. 数据
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 2. 模型(预剪枝)
tree = DecisionTreeClassifier(max_depth=3, min_samples_leaf=2, random_state=42)
tree.fit(X_train, y_train)

# 3. 评估
print("准确率:", accuracy_score(y_test, tree.predict(X_test)))

# 4. 可视化树
plt.figure(figsize=(16, 8))
plot_tree(tree, feature_names=iris.feature_names, class_names=iris.target_names,
          filled=True, rounded=True, fontsize=11)
plt.title("决策树结构")
plt.show()

# 5. 特征重要性
plt.figure(figsize=(8, 4))
plt.barh(iris.feature_names, tree.feature_importances_)
plt.xlabel('重要性')
plt.title('特征重要性')
plt.show()

十三、总结公式卡

项目公式
$ H = -\sum p \log p $
信息增益$ IG = H_{\text{父}} – \sum w_i H_{\text{子}} $
基尼$ Gini = 1 – \sum p_i^2 $
回归分裂$ \text{方差下降} = Var_{\text{父}} – (w_L Var_L + w_R Var_R) $

你想深入哪一步?

  1. 随机森林(100棵树投票)
  2. 后剪枝 + 交叉验证选最佳树
  3. 用真实数据做项目(如泰坦尼克生存预测)
  4. 手写决策树算法(从分裂到预测)

回复 1–4,我立刻带你实战!

文章已创建 2481

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

相关文章

开始在上面输入您的搜索词,然后按回车进行搜索。按ESC取消。

返回顶部