Sklearn 模型保存与加载
在 Scikit-learn(sklearn)中,模型保存与加载是机器学习工作流的重要部分,允许用户将训练好的模型或 Pipeline 存储到磁盘,以便在未来重复使用,而无需重新训练。Scikit-learn 主要通过 joblib 库实现模型的序列化与反序列化。以下是关于 Sklearn 模型保存与加载 的详细指南,涵盖核心概念、方法、代码示例和最佳实践,力求简洁清晰。
一、核心概念
- 保存模型:将训练好的模型(或 Pipeline)序列化为文件,通常使用
joblib.dump。 - 加载模型:从文件中反序列化模型,使用
joblib.load,恢复模型以进行预测。 - 适用对象:不仅限于模型,还包括 Pipeline、预处理对象(如
StandardScaler)等任何实现 sklearn 估算器接口的对象。 - 文件格式:通常保存为
.pkl文件(Pickle 格式),但也可以使用其他扩展名。 - 依赖:
joblib是 Scikit-learn 的推荐工具,优于 Python 内置的pickle模块(效率更高,适合大型 NumPy 数组)。
二、安装与依赖
joblib 通常随 Scikit-learn 一起安装。如果缺失,可手动安装:
pip install joblib
三、保存与加载模型
3.1 保存模型
使用 joblib.dump 将模型保存到文件。
import joblib
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# 数据准备
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练模型
model = LogisticRegression(random_state=42, max_iter=200)
model.fit(X_train, y_train)
# 保存模型
joblib.dump(model, 'logistic_model.pkl')
print("Model saved to logistic_model.pkl")
3.2 加载模型
使用 joblib.load 加载模型并进行预测。
# 加载模型
loaded_model = joblib.load('logistic_model.pkl')
# 预测
y_pred = loaded_model.predict(X_test)
from sklearn.metrics import accuracy_score
print(f"Accuracy of loaded model: {accuracy_score(y_test, y_pred):.3f}")
输出示例:
Model saved to logistic_model.pkl
Accuracy of loaded model: 1.000
四、保存与加载 Pipeline
Pipeline 包含预处理和模型步骤,保存后可直接用于新数据预测。
示例:保存和加载 Pipeline
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
# 创建 Pipeline
pipeline = Pipeline([
('scaler', StandardScaler()),
('classifier', RandomForestClassifier(random_state=42))
])
# 训练 Pipeline
pipeline.fit(X_train, y_train)
# 保存 Pipeline
joblib.dump(pipeline, 'pipeline_model.pkl')
# 加载 Pipeline
loaded_pipeline = joblib.load('pipeline_model.pkl')
# 预测
y_pred = loaded_pipeline.predict(X_test)
print(f"Pipeline Accuracy: {accuracy_score(y_test, y_pred):.3f}")
输出示例:
Pipeline Accuracy: 1.000
五、保存与加载 GridSearchCV 结果
GridSearchCV 的最佳模型可以通过 best_estimator_ 保存。
from sklearn.model_selection import GridSearchCV
# 定义 Pipeline 和参数网格
pipeline = Pipeline([
('scaler', StandardScaler()),
('classifier', LogisticRegression(random_state=42))
])
param_grid = {'classifier__C': [0.1, 1, 10]}
# 网格搜索
grid_search = GridSearchCV(pipeline, param_grid, cv=5)
grid_search.fit(X_train, y_train)
# 保存最佳模型
joblib.dump(grid_search.best_estimator_, 'best_model.pkl')
# 加载并预测
best_model = joblib.load('best_model.pkl')
y_pred = best_model.predict(X_test)
print(f"Best Model Accuracy: {accuracy_score(y_test, y_pred):.3f}")
输出示例:
Best Model Accuracy: 1.000
六、进阶用法
6.1 压缩模型文件
对于大型模型(如随机森林),可以使用 compress 参数减小文件大小。
# 保存时启用压缩
joblib.dump(pipeline, 'pipeline_compressed.pkl', compress=3) # compress 范围 0-9,值越大压缩越高
6.2 保存多个对象
可以将模型、预处理器等保存到一个字典中。
# 保存多个对象
model_dict = {
'model': model,
'scaler': StandardScaler().fit(X_train)
}
joblib.dump(model_dict, 'model_and_scaler.pkl')
# 加载
loaded_dict = joblib.load('model_and_scaler.pkl')
loaded_model = loaded_dict['model']
loaded_scaler = loaded_dict['scaler']
6.3 跨版本兼容性
- Scikit-learn 和
joblib的版本差异可能导致加载失败。建议记录保存时的版本:
import sklearn, joblib
print(f"Scikit-learn: {sklearn.__version__}, Joblib: {joblib.__version__}")
- 为兼容性,建议在新环境中安装相同版本:
pip install scikit-learn==1.5.2 joblib==1.4.2
七、注意事项
- 文件路径:
- 确保保存路径存在且有写入权限。
- 使用绝对路径(如
'/path/to/model.pkl')避免相对路径问题。
- 安全性:
joblib.load执行 Pickle 文件,可能有安全风险。仅加载可信文件。
- 模型完整性:
- 保存的模型包括所有训练参数(如权重、超参数),无需重新训练。
- Pipeline 保存时会保留所有步骤的拟合状态。
- 版本管理:
- 记录训练环境(Python、sklearn、joblib 版本)以确保加载时一致。
- 大模型优化:
- 对于超大模型,考虑
compress或部分保存(如仅保存模型权重)。
- 跨平台兼容性:
- Pickle 文件在不同操作系统间通常兼容,但需注意 Python 版本一致性。
八、完整示例(鸢尾花分类)
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import joblib
# 数据准备
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建并训练 Pipeline
pipeline = Pipeline([
('scaler', StandardScaler()),
('classifier', RandomForestClassifier(random_state=42))
])
pipeline.fit(X_train, y_train)
# 保存 Pipeline
joblib.dump(pipeline, 'iris_pipeline.pkl')
# 加载 Pipeline
loaded_pipeline = joblib.load('iris_pipeline.pkl')
# 预测并评估
y_pred = loaded_pipeline.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
输出示例:
Accuracy: 1.000
九、资源
- 官方文档:https://scikit-learn.org/stable/model_persistence.html
- Joblib 文档:https://joblib.readthedocs.io/en/stable/
- 社区:在 X 平台搜索
#scikit-learn获取最新讨论。
如果需要 可视化保存模型的性能、处理大型模型的优化技巧,或针对特定任务(如保存复杂 Pipeline)的代码,请告诉我,我可以提供更详细的示例!