TensorFlow 高级 API – Keras
TensorFlow 高级 API – Keras 简介
Keras 是 TensorFlow 2.x 集成的高级 API,用于快速构建、训练和部署机器学习模型。它简化了 TensorFlow 的复杂性,提供直观、用户友好的接口,适合初学者和快速原型开发,同时保留了灵活性以支持高级定制。以下是 Keras 的核心概念、功能和使用示例,重点介绍其在 TensorFlow 中的应用。
1. Keras 核心概念
- 层(Layers):Keras 模型由层组成,每层定义了数据如何转换(如全连接层、卷积层)。
- 模型(Models):Keras 提供三种模型构建方式:
- Sequential:线性堆叠层,适合简单模型。
- Functional API:支持复杂拓扑结构(如多输入/多输出模型)。
- Model 子类化:完全自定义模型,适合复杂逻辑。
- 编译(Compile):指定优化器、损失函数和评估指标。
- 训练(Fit):通过数据迭代优化模型参数。
- 回调(Callbacks):在训练过程中执行特定操作(如保存模型、早停)。
2. 主要功能
- 内置层:包括
Dense
(全连接)、Conv2D
(卷积)、LSTM
(循环神经网络)、Embedding
等。 - 数据预处理:支持
tf.keras.preprocessing
(如图像增强、文本分词)。 - 模型保存与加载:支持 HDF5 格式或 SavedModel 格式。
- 自定义能力:支持自定义层、损失函数、指标和模型。
- 与 TensorFlow 生态集成:无缝使用
tf.data
、TensorBoard 和 TensorFlow Hub。
3. 快速入门示例:使用 Sequential 模型
以下是一个使用 Keras Sequential
模型进行 MNIST 手写数字分类的示例:
import tensorflow as tf
from tensorflow.keras import layers, models
# 1. 加载和预处理数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 # 归一化到 [0,1]
# 2. 构建 Sequential 模型
model = models.Sequential([
layers.Flatten(input_shape=(28, 28)), # 展平 28x28 图像
layers.Dense(128, activation='relu'), # 隐藏层,128 神经元
layers.Dropout(0.2), # 防止过拟合
layers.Dense(10, activation='softmax') # 输出层,10 个类别
])
# 3. 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 4. 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)
# 5. 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'测试准确率: {test_acc:.4f}')
输出:训练 5 个 epoch 后,测试准确率通常在 97%-98%。
解释:
- Flatten:将 28×28 图像展平为 784 维向量。
- Dense:全连接层,
relu
增加非线性,softmax
输出分类概率。 - Dropout:随机丢弃 20% 神经元,防止过拟合。
- compile:设置优化器(Adam)、损失函数(交叉熵)和指标(准确率)。
- fit:用 80% 训练数据训练,20% 用于验证。
4. 使用 Functional API
Functional API 适合复杂模型(如多输入/输出)。以下是一个示例:
from tensorflow.keras import Input, layers, Model
# 定义输入
inputs = Input(shape=(28, 28))
x = layers.Flatten()(inputs)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(10, activation='softmax')(x)
# 构建模型
model = Model(inputs=inputs, outputs=outputs)
# 编译和训练
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)
优势:可以定义非线性拓扑,如分支或共享层。
5. Model 子类化
适合高度自定义的模型,需定义 call
方法:
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.flatten = layers.Flatten()
self.dense1 = layers.Dense(128, activation='relu')
self.dropout = layers.Dropout(0.2)
self.dense2 = layers.Dense(10, activation='softmax')
def call(self, inputs):
x = self.flatten(inputs)
x = self.dense1(x)
x = self.dropout(x)
return self.dense2(x)
# 实例化并编译
model = MyModel()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)
适用场景:动态模型、复杂计算逻辑。
6. 常用 Keras 功能
常见层
- 卷积层:
layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')
- 池化层:
layers.MaxPooling2D(pool_size=(2, 2))
- 循环层:
layers.LSTM(units=64)
- 正则化:
layers.Dropout
,layers.BatchNormalization
数据预处理
- 图像增强:
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
rotation_range=10, zoom_range=0.1, horizontal_flip=True)
datagen.fit(x_train)
- 文本处理:
tf.keras.preprocessing.text.Tokenizer
回调(Callbacks)
控制训练过程:
callbacks = [
tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True), # 保存最佳模型
tf.keras.callbacks.EarlyStopping(patience=3) # 早停
]
model.fit(x_train, y_train, epochs=10, callbacks=callbacks)
模型保存与加载
model.save('model.h5') # 保存模型
loaded_model = tf.keras.models.load_model('model.h5') # 加载模型
7. 与 TensorFlow 生态的集成
- tf.data:高效数据管道。
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32).shuffle(1000)
model.fit(dataset, epochs=5)
- TensorBoard:可视化训练过程。
callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs')]
model.fit(x_train, y_train, epochs=5, callbacks=callbacks)
# 运行:tensorboard --logdir ./logs
- TensorFlow Hub:加载预训练模型。
import tensorflow_hub as hub
model = tf.keras.Sequential([hub.KerasLayer("https://tfhub.dev/google/imagenet/resnet_v2_50/classification/5")])
8. 性能优化
- 批量大小:调整
batch_size
(如 32、64)平衡速度和内存。 - 混合精度训练:
from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)
- GPU 加速:确保 TensorFlow GPU 版本正确安装(见环境搭建)。
9. 总结
Keras 作为 TensorFlow 的高级 API,通过 Sequential
、Functional API 和 Model 子类化提供了从简单到复杂的建模能力。其内置层、预处理工具和回调功能简化了开发流程,与 tf.data
和 TensorBoard 等生态工具无缝集成。
如果你需要更具体的 Keras 示例(如 CNN、RNN)、某功能深入讲解,或想生成训练过程中的损失/准确率图表,请告诉我!