TensorFlow 高级 API – Keras

TensorFlow 高级 API – Keras 简介

Keras 是 TensorFlow 2.x 集成的高级 API,用于快速构建、训练和部署机器学习模型。它简化了 TensorFlow 的复杂性,提供直观、用户友好的接口,适合初学者和快速原型开发,同时保留了灵活性以支持高级定制。以下是 Keras 的核心概念、功能和使用示例,重点介绍其在 TensorFlow 中的应用。


1. Keras 核心概念

  • 层(Layers):Keras 模型由层组成,每层定义了数据如何转换(如全连接层、卷积层)。
  • 模型(Models):Keras 提供三种模型构建方式:
  • Sequential:线性堆叠层,适合简单模型。
  • Functional API:支持复杂拓扑结构(如多输入/多输出模型)。
  • Model 子类化:完全自定义模型,适合复杂逻辑。
  • 编译(Compile):指定优化器、损失函数和评估指标。
  • 训练(Fit):通过数据迭代优化模型参数。
  • 回调(Callbacks):在训练过程中执行特定操作(如保存模型、早停)。

2. 主要功能

  • 内置层:包括 Dense(全连接)、Conv2D(卷积)、LSTM(循环神经网络)、Embedding 等。
  • 数据预处理:支持 tf.keras.preprocessing(如图像增强、文本分词)。
  • 模型保存与加载:支持 HDF5 格式或 SavedModel 格式。
  • 自定义能力:支持自定义层、损失函数、指标和模型。
  • 与 TensorFlow 生态集成:无缝使用 tf.data、TensorBoard 和 TensorFlow Hub。

3. 快速入门示例:使用 Sequential 模型

以下是一个使用 Keras Sequential 模型进行 MNIST 手写数字分类的示例:

import tensorflow as tf
from tensorflow.keras import layers, models

# 1. 加载和预处理数据
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # 归一化到 [0,1]

# 2. 构建 Sequential 模型
model = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),  # 展平 28x28 图像
    layers.Dense(128, activation='relu'),  # 隐藏层,128 神经元
    layers.Dropout(0.2),                   # 防止过拟合
    layers.Dense(10, activation='softmax') # 输出层,10 个类别
])

# 3. 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 4. 训练模型
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

# 5. 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'测试准确率: {test_acc:.4f}')

输出:训练 5 个 epoch 后,测试准确率通常在 97%-98%。

解释

  • Flatten:将 28×28 图像展平为 784 维向量。
  • Dense:全连接层,relu 增加非线性,softmax 输出分类概率。
  • Dropout:随机丢弃 20% 神经元,防止过拟合。
  • compile:设置优化器(Adam)、损失函数(交叉熵)和指标(准确率)。
  • fit:用 80% 训练数据训练,20% 用于验证。

4. 使用 Functional API

Functional API 适合复杂模型(如多输入/输出)。以下是一个示例:

from tensorflow.keras import Input, layers, Model

# 定义输入
inputs = Input(shape=(28, 28))
x = layers.Flatten()(inputs)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(10, activation='softmax')(x)

# 构建模型
model = Model(inputs=inputs, outputs=outputs)

# 编译和训练
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)

优势:可以定义非线性拓扑,如分支或共享层。


5. Model 子类化

适合高度自定义的模型,需定义 call 方法:

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = layers.Flatten()
        self.dense1 = layers.Dense(128, activation='relu')
        self.dropout = layers.Dropout(0.2)
        self.dense2 = layers.Dense(10, activation='softmax')

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense1(x)
        x = self.dropout(x)
        return self.dense2(x)

# 实例化并编译
model = MyModel()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)

适用场景:动态模型、复杂计算逻辑。


6. 常用 Keras 功能

常见层

  • 卷积层layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')
  • 池化层layers.MaxPooling2D(pool_size=(2, 2))
  • 循环层layers.LSTM(units=64)
  • 正则化layers.Dropout, layers.BatchNormalization

数据预处理

  • 图像增强
  datagen = tf.keras.preprocessing.image.ImageDataGenerator(
      rotation_range=10, zoom_range=0.1, horizontal_flip=True)
  datagen.fit(x_train)
  • 文本处理tf.keras.preprocessing.text.Tokenizer

回调(Callbacks)

控制训练过程:

callbacks = [
    tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True),  # 保存最佳模型
    tf.keras.callbacks.EarlyStopping(patience=3)  # 早停
]
model.fit(x_train, y_train, epochs=10, callbacks=callbacks)

模型保存与加载

model.save('model.h5')  # 保存模型
loaded_model = tf.keras.models.load_model('model.h5')  # 加载模型

7. 与 TensorFlow 生态的集成

  • tf.data:高效数据管道。
  dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32).shuffle(1000)
  model.fit(dataset, epochs=5)
  • TensorBoard:可视化训练过程。
  callbacks = [tf.keras.callbacks.TensorBoard(log_dir='./logs')]
  model.fit(x_train, y_train, epochs=5, callbacks=callbacks)
  # 运行:tensorboard --logdir ./logs
  • TensorFlow Hub:加载预训练模型。
  import tensorflow_hub as hub
  model = tf.keras.Sequential([hub.KerasLayer("https://tfhub.dev/google/imagenet/resnet_v2_50/classification/5")])

8. 性能优化

  • 批量大小:调整 batch_size(如 32、64)平衡速度和内存。
  • 混合精度训练
  from tensorflow.keras import mixed_precision
  policy = mixed_precision.Policy('mixed_float16')
  mixed_precision.set_global_policy(policy)
  • GPU 加速:确保 TensorFlow GPU 版本正确安装(见环境搭建)。

9. 总结

Keras 作为 TensorFlow 的高级 API,通过 Sequential、Functional API 和 Model 子类化提供了从简单到复杂的建模能力。其内置层、预处理工具和回调功能简化了开发流程,与 tf.data 和 TensorBoard 等生态工具无缝集成。

如果你需要更具体的 Keras 示例(如 CNN、RNN)、某功能深入讲解,或想生成训练过程中的损失/准确率图表,请告诉我!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注