TensorFlow 实例 – 文本分类项目
TensorFlow 实例 – 文本分类项目
本教程将展示如何使用 TensorFlow 和 Keras 构建一个完整的文本分类项目,以 IMDB 电影评论数据集为例,目标是进行情感分析(区分正面和负面评论)。教程涵盖数据加载、预处理、模型构建、训练、评估和结果可视化,适合初学者和需要实用示例的用户。代码简洁,包含调优技巧和性能优化。如果需要更复杂的模型、其他数据集或特定功能(如使用 BERT),请告诉我!
1. 项目目标
- 任务:对 IMDB 数据集中的电影评论进行情感分类,分为正面(1)和负面(0)。
- 数据集:IMDB,包含 25,000 条训练评论和 25,000 条测试评论,每条评论已编码为词索引序列。
- 输出:分类模型,预测评论的情感倾向,并可视化训练过程和结果。
2. 环境准备
确保安装 TensorFlow 和相关库:
pip install tensorflow matplotlib
验证 TensorFlow:
import tensorflow as tf
print(tf.__version__) # 确保版本为 2.x(如 2.17.0)
3. 完整代码
以下是完整的文本分类项目代码,包含数据处理、模型构建、训练和评估:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
import numpy as np
# 1. 加载和预处理数据
vocab_size = 10000 # 词汇表大小(只保留最常见词)
max_len = 200 # 序列最大长度
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocab_size)
# 填充序列到统一长度
x_train = pad_sequences(x_train, maxlen=max_len, padding='post')
x_test = pad_sequences(x_test, maxlen=max_len, padding='post')
# 2. 创建数据管道
train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(buffer_size=1000)
.batch(batch_size=32)
.prefetch(tf.data.AUTOTUNE))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
# 3. 构建 LSTM 模型
model = models.Sequential([
layers.Embedding(vocab_size, 128, input_length=max_len), # 词嵌入层
layers.LSTM(64, return_sequences=False), # LSTM 层
layers.Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01)),
layers.Dropout(0.5), # 防止过拟合
layers.Dense(1, activation='sigmoid') # 二分类输出
])
# 4. 编译模型
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss='binary_crossentropy',
metrics=['accuracy']
)
# 5. 训练模型
callbacks = [
tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
tf.keras.callbacks.ModelCheckpoint('imdb_model.h5', save_best_only=True),
tf.keras.callbacks.TensorBoard(log_dir='./logs')
]
history = model.fit(
train_dataset,
epochs=10,
validation_data=test_dataset,
callbacks=callbacks
)
# 6. 评估模型
test_loss, test_acc = model.evaluate(test_dataset)
print(f'\n测试集准确率: {test_acc:.4f}')
# 7. 可视化训练过程
plt.figure(figsize=(12, 4))
# 绘制准确率
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
# 绘制损失
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
# 8. 可视化预测结果
predictions = model.predict(test_dataset)
predicted_labels = (predictions > 0.5).astype(int).flatten()
# 获取词典以解码评论
word_index = imdb.get_word_index()
reverse_word_index = {value: key for key, value in word_index.items()}
def decode_review(sequence):
return ' '.join([reverse_word_index.get(i - 3, '?') for i in sequence if i > 2])
# 显示前 5 条测试评论及其预测
for i in range(5):
print(f'\n评论 {i + 1}:')
print(decode_review(x_test[i]))
print(f'预测情感: {"正面" if predicted_labels[i] == 1 else "负面"}')
print(f'真实情感: {"正面" if y_test[i] == 1 else "负面"}')
4. 代码逐部分解释
4.1 数据加载与预处理
- IMDB 数据集:包含 25,000 条训练和测试评论,评论已分词为词索引(整数序列)。
vocab_size=10000
:只保留最常见的 10,000 个词。pad_sequences
:将序列填充或截断到固定长度(max_len=200
)。- 数据管道:使用
tf.data
实现打乱、批处理和预取,优化训练效率。
4.2 模型结构
- Embedding 层:将词索引转为 128 维密集向量,表示词的语义。
- LSTM 层:捕捉评论中的序列信息,适合长文本处理。
- Dense 层:结合 ReLU 激活和 L2 正则化,提取高级特征。
- Dropout:随机丢弃 50% 神经元,防止过拟合。
- 输出层:使用 sigmoid 激活,输出正面(1)或负面(0)的概率。
4.3 编译与训练
- 优化器:Adam,学习率为 0.001,适合大多数 NLP 任务。
- 损失函数:
binary_crossentropy
,适用于二分类任务。 - 回调:
EarlyStopping
:验证损失 3 个 epoch 无改进则停止。ModelCheckpoint
:保存最佳模型。TensorBoard
:记录训练日志,可视化指标。
4.4 评估与可视化
- 评估:在测试集上计算损失和准确率。
- 可视化:
- 绘制训练和验证的准确率/损失曲线,检查模型性能。
- 显示前 5 条测试评论的原文、预测情感和真实情感。
5. 运行结果
- 训练:10 个 epoch(可能因早停提前结束),测试准确率通常在 85%-90%。
- 可视化:
- 准确率和损失曲线反映训练过程,验证准确率接近训练准确率表明泛化良好。
- 预测结果显示评论文本及其情感分类,方便检查模型表现。
示例输出:
测试集准确率: 0.8750
评论 1:
the film was good but the story was too predictable
预测情感: 负面
真实情感: 负面
评论 2:
amazing movie with great acting and a powerful message
预测情感: 正面
真实情感: 正面
6. 生成图表
以下是训练过程中准确率和损失的示例图表:
{
"type": "line",
"data": {
"labels": ["Epoch 1", "Epoch 2", "Epoch 3", "Epoch 4", "Epoch 5"],
"datasets": [
{
"label": "Training Accuracy",
"data": [0.65, 0.75, 0.82, 0.85, 0.87], // 示例数据
"borderColor": "#1f77b4",
"fill": false
},
{
"label": "Validation Accuracy",
"data": [0.68, 0.76, 0.80, 0.83, 0.85], // 示例数据
"borderColor": "#ff7f0e",
"fill": false
}
]
},
"options": {
"scales": {
"x": { "title": { "display": true, "text": "Epoch" } },
"y": { "title": { "display": true, "text": "Accuracy" }, "beginAtZero": false }
}
}
}
说明:实际数据来自 history.history['accuracy']
和 history.history['val_accuracy']
。损失曲线类似,可从 history.history['loss']
和 history.history['val_loss']
获取。
7. 优化建议
- 提高准确率:
- 使用更复杂的模型,如双向 LSTM 或 Transformer:
python model.add(layers.Bidirectional(layers.LSTM(64)))
- 引入预训练嵌入(如 BERT):
python import tensorflow_hub as hub model = models.Sequential([ hub.KerasLayer("https://tfhub.dev/google/nnlm-en-dim128/2"), layers.Dense(1, activation='sigmoid') ])
- 数据增强:
- 使用同义词替换或随机删除词(需借助
nlpaug
等库)。 - 加速训练:
- 启用混合精度训练:
python from tensorflow.keras import mixed_precision mixed_precision.set_global_policy('mixed_float16')
- 使用多 GPU:
tf.distribute.MirroredStrategy
。 - 防止过拟合:
- 增加
Dropout
比例或 L2 正则化强度。 - 缩短序列长度(
max_len
)或限制词汇表(vocab_size
)。 - 监控性能:
- 使用 TensorBoard 查看详细指标:
tensorboard --logdir ./logs
。 - 检查混淆矩阵:
python from sklearn.metrics import confusion_matrix import seaborn as sns cm = confusion_matrix(y_test, predicted_labels) sns.heatmap(cm, annot=True, fmt='d', xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive']) plt.show()
8. 常见问题与解决
- 准确率低:
- 增加 epoch 数或调整学习率(
learning_rate=0.0001
)。 - 使用预训练嵌入或更复杂的模型。
- 过拟合:
- 验证准确率低于训练准确率,增加正则化或数据增强。
- 训练慢:
- 确保 GPU 可用:
tf.config.list_physical_devices('GPU')
。 - 优化数据管道(
prefetch
,cache
)。 - 内存不足:
- 减小
batch_size
。 - 使用 TFRecord 存储数据:
python def create_tfrecord(text, label): feature = { 'text': tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode()])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])) } return tf.train.Example(features=tf.train.Features(feature=feature))
9. 总结
本项目展示了使用 TensorFlow 和 Keras 进行文本分类的完整流程,基于 IMDB 数据集的情感分析。示例中的 LSTM 模型结合了词嵌入、正则化和高效数据管道,适合 NLP 入门。如果需要进一步优化(如使用 Transformer、BERT 或处理自定义文本数据),或想生成更多图表(如混淆矩阵、词嵌入可视化),请告诉我!
需要更多内容?
部署模型(TensorFlow Lite 或 Serving)。
更复杂的模型(如 Transformer、BERT)。
其他数据集(如 Twitter 情感分析、自定义文本)。
额外图表(如混淆矩阵、ROC 曲线)。