# Sklearn Machine Learning Models
Scikit-learn (sklearn) provides a rich, easy-to-use collection of machine learning models covering supervised learning (classification, regression) and unsupervised learning (clustering, dimensionality reduction). All models share a unified interface, fit(), predict(), and score(), which makes it easy to swap models and iterate quickly.
Below is an overview of sklearn's core model categories, the most commonly used models, code examples, and model selection advice, aimed at getting you productive quickly.
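As a starting point, here is a minimal sketch of that unified estimator interface; the synthetic dataset from make_classification is only a stand-in for your own data:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic data as a placeholder for a real dataset
X, y = make_classification(n_samples=200, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)           # train
print(model.predict(X_test[:5]))      # predict class labels
print(model.score(X_test, y_test))    # mean accuracy on the test set
```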
## I. Model Categories Overview
| Task type | Module | Typical models |
|---|---|---|
| Classification | sklearn.linear_model / sklearn.ensemble | Logistic regression, SVM, random forest |
| Regression | sklearn.linear_model / sklearn.ensemble | Linear regression, ridge regression, GBDT |
| Clustering | sklearn.cluster | K-Means, DBSCAN |
| Dimensionality reduction | sklearn.decomposition / sklearn.manifold | PCA, t-SNE |
| Model selection | sklearn.model_selection | GridSearchCV, cross-validation |
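The model_selection entry covers more than grid search; here is a minimal sketch of k-fold cross-validation with cross_val_score (the iris dataset and random forest are only illustrative choices):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(n_estimators=100, random_state=42)

# 5-fold cross-validation: one accuracy score per fold
scores = cross_val_score(model, X, y, cv=5)
print(scores.mean(), scores.std())
```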
## II. Common Models in Detail + Code Examples
### 1. Classification Models
| Model | Typical use case |
|---|---|
| Logistic regression | Linearly separable data, binary classification, probability output |
| Support vector machine (SVM) | Small datasets, non-linear boundaries (via kernel functions) |
| Random forest | Many features, robust, resistant to overfitting |
| Gradient boosting trees (GBDT) | High accuracy, popular in competitions |
| K-nearest neighbors (KNN) | Small datasets, lazy learning |

```python
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier

# Logistic regression
model = LogisticRegression(max_iter=200)
model.fit(X_train, y_train)

# Support vector machine (RBF kernel)
model = SVC(kernel='rbf', C=1.0)
model.fit(X_train, y_train)

# Random forest
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Gradient boosting trees
model = GradientBoostingClassifier(learning_rate=0.1)
model.fit(X_train, y_train)

# K-nearest neighbors
model = KNeighborsClassifier(n_neighbors=5)
model.fit(X_train, y_train)
```
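Since the logistic regression row mentions probability output, here is a minimal sketch of how to obtain per-class probabilities, assuming the X_train, y_train, and X_test variables from the examples above:

```python
from sklearn.linear_model import LogisticRegression

clf = LogisticRegression(max_iter=200)
clf.fit(X_train, y_train)

# predict_proba returns one row per sample, one column per class (rows sum to 1)
print(clf.predict_proba(X_test)[:3])
```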
---
### 2. Regression Models
| Model | Typical use case |
|---|---|
| Linear regression | Linear relationships, high interpretability |
| Ridge regression | Multicollinearity, guards against overfitting |
| Lasso regression | Feature selection, sparse models |
| Random forest regression | Non-linear relationships, robust |

```python
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.ensemble import RandomForestRegressor

# Linear regression
model = LinearRegression()
model.fit(X_train, y_train)

# Ridge regression (L2 regularization)
model = Ridge(alpha=1.0)
model.fit(X_train, y_train)

# Lasso regression (L1 regularization)
model = Lasso(alpha=0.1)
model.fit(X_train, y_train)

# Random forest regression
model = RandomForestRegressor(n_estimators=100)
model.fit(X_train, y_train)
```
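To see why Lasso is listed under feature selection, here is a minimal sketch on a synthetic problem; make_regression and alpha=1.0 are illustrative choices, not part of the original example:

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso

# 20 features, but only 5 carry signal
X, y = make_regression(n_samples=200, n_features=20, n_informative=5, random_state=0)

lasso = Lasso(alpha=1.0)
lasso.fit(X, y)

# L1 regularization drives most coefficients exactly to zero
print(np.sum(lasso.coef_ != 0), "non-zero coefficients out of", X.shape[1])
```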
---
### 3. Clustering Models (Unsupervised)
| Model | Typical use case |
|---|---|
| K-Means | Spherical clusters, known number of clusters |
| DBSCAN | Arbitrarily shaped clusters, automatic noise detection |
| Hierarchical clustering | Tree structure, easy to visualize |

```python
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering

# K-Means
kmeans = KMeans(n_clusters=3, random_state=42)
clusters = kmeans.fit_predict(X)

# DBSCAN
dbscan = DBSCAN(eps=0.5, min_samples=5)
clusters = dbscan.fit_predict(X)

# Hierarchical (agglomerative) clustering
model = AgglomerativeClustering(n_clusters=3)
clusters = model.fit_predict(X)
```
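K-Means requires the number of clusters up front; a common way to choose it is the silhouette score, shown in this minimal sketch (the candidate range 2 to 6 and the iris data are arbitrary choices):

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import silhouette_score

X, _ = load_iris(return_X_y=True)

# Try several cluster counts and keep the one with the highest silhouette score
for k in range(2, 7):
    labels = KMeans(n_clusters=k, random_state=42).fit_predict(X)
    print(k, round(silhouette_score(X, labels), 3))
```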
### 4. Dimensionality Reduction Models
| Model | Typical use case |
|---|---|
| PCA (principal component analysis) | Decorrelation, dimensionality reduction, visualization |
| t-SNE | Visualizing high-dimensional data (non-linear) |

```python
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# PCA: project onto the top 2 principal components
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

# t-SNE: non-linear embedding for visualization
tsne = TSNE(n_components=2)
X_tsne = tsne.fit_transform(X)
```
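When using PCA for dimensionality reduction, the explained-variance ratio helps decide how many components to keep; a minimal sketch (iris is only a stand-in dataset):

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

X, _ = load_iris(return_X_y=True)

pca = PCA().fit(X)                              # keep all components
print(pca.explained_variance_ratio_)            # variance explained by each component
print(pca.explained_variance_ratio_.cumsum())   # cumulative share
```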
## III. Complete Example: Iris Classification (Multi-Model Comparison)
```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier

# Models to compare
models = {
    'Logistic Regression': LogisticRegression(max_iter=200),
    'Random Forest': RandomForestClassifier(n_estimators=100),
    'SVM': SVC(kernel='rbf'),
    'KNN': KNeighborsClassifier(n_neighbors=5)
}

# Data preparation
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardization (important for SVM and KNN)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Train and evaluate
for name, model in models.items():
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    print(f"{name}: Accuracy = {acc:.3f}")
```
Example output:
```
Logistic Regression: Accuracy = 1.000
Random Forest: Accuracy = 1.000
SVM: Accuracy = 1.000
KNN: Accuracy = 1.000
```
## IV. Model Evaluation Metrics
| Task | Common metrics |
|---|---|
| Classification | Accuracy, precision, recall, F1, AUC |
| Regression | MSE, MAE, R² |
| Clustering | Silhouette coefficient, ARI |

```python
# Classification
from sklearn.metrics import classification_report, roc_auc_score
print(classification_report(y_test, y_pred))

# Regression
from sklearn.metrics import mean_squared_error, r2_score
print("R²:", r2_score(y_test, y_pred))

# Clustering
from sklearn.metrics import silhouette_score
print("Silhouette:", silhouette_score(X, clusters))
```
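The regression row above only prints R²; here is a minimal sketch that also reports MSE and MAE, on a synthetic regression fit (the dataset and model are placeholders):

```python
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

X, y = make_regression(n_samples=200, n_features=5, noise=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

reg = LinearRegression().fit(X_train, y_train)
y_pred = reg.predict(X_test)

print("MSE:", mean_squared_error(y_test, y_pred))
print("MAE:", mean_absolute_error(y_test, y_pred))
print("R²:", r2_score(y_test, y_pred))
```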
---
## V. Model Selection Advice (Quick Reference)
| Scenario | Recommended models |
|---|---|
| Small dataset, interpretability matters | Logistic regression, linear regression |
| Many features, non-linear relationships | Random forest, GBDT |
| High-dimensional sparse text | Naive Bayes, SVM |
| Images / small samples | SVM, KNN |
| Customer segmentation | K-Means, DBSCAN |
| Visualizing high-dimensional data | PCA, t-SNE |
---
## VI. Saving and Loading Models
```python
import joblib

# Save
joblib.dump(model, 'rf_model.pkl')

# Load
loaded_model = joblib.load('rf_model.pkl')
y_pred = loaded_model.predict(X_new)
```
---
## VII. Going Further: Pipeline + Grid Search
```python
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier

# Chain preprocessing and the model so both are tuned together
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('rf', RandomForestClassifier(random_state=42))
])

param_grid = {
    'rf__n_estimators': [50, 100, 200],
    'rf__max_depth': [None, 10, 20]
}

grid = GridSearchCV(pipeline, param_grid, cv=5)
grid.fit(X_train, y_train)
print("Best params:", grid.best_params_)
```
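After the search finishes, the best pipeline is refit on the full training data and can be used directly; a short sketch, assuming the X_test and y_test from the iris example above:

```python
# Cross-validated score of the best parameter combination
print("Best CV score:", grid.best_score_)

# grid.best_estimator_ is the pipeline refit with the best parameters
best_model = grid.best_estimator_
print("Test accuracy:", best_model.score(X_test, y_test))
```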
## VIII. Official Resources
- Model overview: https://scikit-learn.org/stable/supervised_learning.html
- Algorithm selection map: https://scikit-learn.org/stable/tutorial/machine_learning_map/index.html
- Example gallery: https://scikit-learn.org/stable/auto_examples/
Summary:
sklearn offers a unified model interface, clear documentation, and a mature ecosystem, making it the natural first choice for getting started with machine learning.
Suggested path: master logistic regression first, then random forest, then Pipeline + GridSearch.