Sklearn 模型保存与加载

在 Scikit-learn(sklearn)中,模型保存与加载是机器学习工作流的重要部分,允许用户将训练好的模型或 Pipeline 存储到磁盘,以便在未来重复使用,而无需重新训练。Scikit-learn 主要通过 joblib 库实现模型的序列化与反序列化。以下是关于 Sklearn 模型保存与加载 的详细指南,涵盖核心概念、方法、代码示例和最佳实践,力求简洁清晰。


一、核心概念

  • 保存模型:将训练好的模型(或 Pipeline)序列化为文件,通常使用 joblib.dump
  • 加载模型:从文件中反序列化模型,使用 joblib.load,恢复模型以进行预测。
  • 适用对象:不仅限于模型,还包括 Pipeline、预处理对象(如 StandardScaler)等任何实现 sklearn 估算器接口的对象。
  • 文件格式:通常保存为 .pkl 文件(Pickle 格式),但也可以使用其他扩展名。
  • 依赖joblib 是 Scikit-learn 的推荐工具,优于 Python 内置的 pickle 模块(效率更高,适合大型 NumPy 数组)。

二、安装与依赖

joblib 通常随 Scikit-learn 一起安装。如果缺失,可手动安装:

pip install joblib

三、保存与加载模型

3.1 保存模型

使用 joblib.dump 将模型保存到文件。

import joblib
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 数据准备
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 训练模型
model = LogisticRegression(random_state=42, max_iter=200)
model.fit(X_train, y_train)

# 保存模型
joblib.dump(model, 'logistic_model.pkl')
print("Model saved to logistic_model.pkl")

3.2 加载模型

使用 joblib.load 加载模型并进行预测。

# 加载模型
loaded_model = joblib.load('logistic_model.pkl')

# 预测
y_pred = loaded_model.predict(X_test)
from sklearn.metrics import accuracy_score
print(f"Accuracy of loaded model: {accuracy_score(y_test, y_pred):.3f}")

输出示例

Model saved to logistic_model.pkl
Accuracy of loaded model: 1.000

四、保存与加载 Pipeline

Pipeline 包含预处理和模型步骤,保存后可直接用于新数据预测。

示例:保存和加载 Pipeline

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier

# 创建 Pipeline
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', RandomForestClassifier(random_state=42))
])

# 训练 Pipeline
pipeline.fit(X_train, y_train)

# 保存 Pipeline
joblib.dump(pipeline, 'pipeline_model.pkl')

# 加载 Pipeline
loaded_pipeline = joblib.load('pipeline_model.pkl')

# 预测
y_pred = loaded_pipeline.predict(X_test)
print(f"Pipeline Accuracy: {accuracy_score(y_test, y_pred):.3f}")

输出示例

Pipeline Accuracy: 1.000

五、保存与加载 GridSearchCV 结果

GridSearchCV 的最佳模型可以通过 best_estimator_ 保存。

from sklearn.model_selection import GridSearchCV

# 定义 Pipeline 和参数网格
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', LogisticRegression(random_state=42))
])
param_grid = {'classifier__C': [0.1, 1, 10]}

# 网格搜索
grid_search = GridSearchCV(pipeline, param_grid, cv=5)
grid_search.fit(X_train, y_train)

# 保存最佳模型
joblib.dump(grid_search.best_estimator_, 'best_model.pkl')

# 加载并预测
best_model = joblib.load('best_model.pkl')
y_pred = best_model.predict(X_test)
print(f"Best Model Accuracy: {accuracy_score(y_test, y_pred):.3f}")

输出示例

Best Model Accuracy: 1.000

六、进阶用法

6.1 压缩模型文件

对于大型模型(如随机森林),可以使用 compress 参数减小文件大小。

# 保存时启用压缩
joblib.dump(pipeline, 'pipeline_compressed.pkl', compress=3)  # compress 范围 0-9,值越大压缩越高

6.2 保存多个对象

可以将模型、预处理器等保存到一个字典中。

# 保存多个对象
model_dict = {
    'model': model,
    'scaler': StandardScaler().fit(X_train)
}
joblib.dump(model_dict, 'model_and_scaler.pkl')

# 加载
loaded_dict = joblib.load('model_and_scaler.pkl')
loaded_model = loaded_dict['model']
loaded_scaler = loaded_dict['scaler']

6.3 跨版本兼容性

  • Scikit-learn 和 joblib 的版本差异可能导致加载失败。建议记录保存时的版本:
  import sklearn, joblib
  print(f"Scikit-learn: {sklearn.__version__}, Joblib: {joblib.__version__}")
  • 为兼容性,建议在新环境中安装相同版本:
  pip install scikit-learn==1.5.2 joblib==1.4.2

七、注意事项

  1. 文件路径
  • 确保保存路径存在且有写入权限。
  • 使用绝对路径(如 '/path/to/model.pkl')避免相对路径问题。
  1. 安全性
  • joblib.load 执行 Pickle 文件,可能有安全风险。仅加载可信文件。
  1. 模型完整性
  • 保存的模型包括所有训练参数(如权重、超参数),无需重新训练。
  • Pipeline 保存时会保留所有步骤的拟合状态。
  1. 版本管理
  • 记录训练环境(Python、sklearn、joblib 版本)以确保加载时一致。
  1. 大模型优化
  • 对于超大模型,考虑 compress 或部分保存(如仅保存模型权重)。
  1. 跨平台兼容性
  • Pickle 文件在不同操作系统间通常兼容,但需注意 Python 版本一致性。

八、完整示例(鸢尾花分类)

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import joblib

# 数据准备
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建并训练 Pipeline
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('classifier', RandomForestClassifier(random_state=42))
])
pipeline.fit(X_train, y_train)

# 保存 Pipeline
joblib.dump(pipeline, 'iris_pipeline.pkl')

# 加载 Pipeline
loaded_pipeline = joblib.load('iris_pipeline.pkl')

# 预测并评估
y_pred = loaded_pipeline.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")

输出示例

Accuracy: 1.000

九、资源

  • 官方文档:https://scikit-learn.org/stable/model_persistence.html
  • Joblib 文档:https://joblib.readthedocs.io/en/stable/
  • 社区:在 X 平台搜索 #scikit-learn 获取最新讨论。

如果需要 可视化保存模型的性能处理大型模型的优化技巧,或针对特定任务(如保存复杂 Pipeline)的代码,请告诉我,我可以提供更详细的示例!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注