PyTorch 模型部署

PyTorch 模型部署是将训练好的模型应用于生产环境(如 Web 应用、移动端、嵌入式设备等)的过程。PyTorch 提供了多种工具和方法来简化模型部署,包括模型保存/加载、ONNX 导出、TorchScript、以及与推理框架(如 TorchServe、TensorRT)的集成。以下是关于 PyTorch 模型部署的详细指南,涵盖核心步骤、方法和示例,力求简洁且实用。


1. PyTorch 模型部署概述

模型部署的目标是将训练好的 PyTorch 模型高效、可靠地应用于实际场景。常见部署场景包括:

  • 服务器端推理:在云端或本地服务器上运行模型(如 API 服务)。
  • 边缘设备:在移动设备或嵌入式硬件上运行(如手机、IoT 设备)。
  • 跨框架兼容:将模型导出为 ONNX 格式,供其他框架(如 TensorFlow、ONNX Runtime)使用。
  • 高性能推理:使用优化工具(如 NVIDIA TensorRT)加速推理。

PyTorch 提供了以下工具支持部署:

  • 模型保存与加载:通过 torch.savetorch.load
  • TorchScript:将模型转换为可序列化的中间表示(IR)。
  • ONNX:导出为跨平台格式。
  • TorchServe:用于服务器端部署的专用工具。
  • LibTorch:C++ 接口,适合嵌入式或高性能场景。

2. 核心部署步骤

(1) 保存与加载模型

PyTorch 模型通常通过保存权重或整个模型进行序列化。

  • 保存权重(推荐)
    保存模型的 state_dict,占用空间小,灵活性高。
  import torch
  import torch.nn as nn

  # 示例模型
  model = nn.Linear(10, 2)
  torch.save(model.state_dict(), 'model_weights.pth')

  # 加载权重
  model.load_state_dict(torch.load('model_weights.pth'))
  model.eval()  # 切换到推理模式
  • 保存整个模型
    保存模型结构和权重,但文件较大,兼容性稍差。
  torch.save(model, 'model.pth')
  model = torch.load('model.pth')
  model.eval()
  • 注意事项
  • 使用 model.eval() 禁用 Dropout 和 BatchNorm 的训练行为。
  • 确保设备一致性(CPU/GPU),使用 map_location 参数:
    python model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))

(2) TorchScript 部署

TorchScript 将 PyTorch 模型转换为可序列化的中间表示,支持 Python 环境外的推理(如 C++)。

  • 两种转换方式
  • Tracing:通过示例输入追踪模型,适合简单模型。
    python model.eval() example_input = torch.randn(1, 10) traced_model = torch.jit.trace(model, example_input) traced_model.save('model_traced.pt')
  • Scripting:编译整个模型,适合包含控制流的复杂模型。 scripted_model = torch.jit.script(model) scripted_model.save('model_scripted.pt')
  • 加载与推理
  loaded_model = torch.jit.load('model_traced.pt')
  output = loaded_model(torch.randn(1, 10))
  • C++ 推理
    使用 LibTorch 加载 TorchScript 模型:
  #include <torch/script.h>
  auto model = torch::jit::load("model_traced.pt");
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::randn({1, 10}));
  auto output = model.forward(inputs).toTensor();

(3) ONNX 导出

ONNX(Open Neural Network Exchange)是一种跨框架模型格式,适合在不同平台(如 ONNX Runtime、TensorFlow)上部署。

  • 导出 ONNX 模型
  model.eval()
  dummy_input = torch.randn(1, 10)
  torch.onnx.export(
      model,
      dummy_input,
      'model.onnx',
      input_names=['input'],
      output_names=['output'],
      dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
  )
  • 推理(使用 ONNX Runtime)
  import onnxruntime as ort
  import numpy as np

  session = ort.InferenceSession('model.onnx')
  input_name = session.get_inputs()[0].name
  output = session.run(None, {input_name: np.random.randn(1, 10).astype(np.float32)})[0]
  • 注意事项
  • 确保模型支持 ONNX 算子(部分自定义操作可能不兼容)。
  • 指定 dynamic_axes 支持动态批大小。
  • 检查 ONNX 模型:pip install onnxonnx.checker.check_model(model)

(4) TorchServe 部署

TorchServe 是 PyTorch 官方提供的模型服务框架,适合服务器端部署。

  • 安装
  pip install torchserve torch-model-archiver
  • 打包模型
  torch-model-archiver --model-name my_model \
      --version 1.0 \
      --model-file model.py \
      --serialized-file model_weights.pth \
      --handler image_classifier \
      --extra-files index_to_name.json
  • 启动服务
  torchserve --start --model-store model_store --models my_model=my_model.mar
  • 推理请求
    使用 HTTP 请求调用模型:
  import requests
  response = requests.post('http://localhost:8080/predictions/my_model', data=open('image.jpg', 'rb'))
  print(response.json())

(5) 优化与加速

  • 混合精度推理
    使用 torch.cuda.amp 加速推理:
  from torch.cuda.amp import autocast
  model.eval()
  with torch.no_grad(), autocast():
      output = model(input.to('cuda'))
  • NVIDIA TensorRT
    转换 ONNX 模型为 TensorRT 引擎以加速 GPU 推理:
  import tensorrt as trt
  import numpy as np
  import pycuda.driver as cuda
  import pycuda.autoinit

  # 加载 ONNX 模型并转换为 TensorRT 引擎(需要额外配置)
  • 量化
    减少模型大小和推理时间:
  • 动态量化
    python quantized_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
  • 静态量化:需要校准数据,适合卷积网络。

3. 完整示例:部署图像分类模型

以下是一个从训练到部署的完整流程,使用 ResNet18 进行图像分类并导出为 ONNX 和 TorchScript。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# 超参数
num_classes = 10
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 数据准备
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 模型
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.to(device)

# 训练
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
for epoch in range(2):  # 简化示例
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 保存权重
torch.save(model.state_dict(), 'resnet18_cifar10.pth')

# 导出 TorchScript
model.eval()
example_input = torch.randn(1, 3, 224, 224).to(device)
traced_model = torch.jit.trace(model, example_input)
traced_model.save('resnet18_cifar10.pt')

# 导出 ONNX
torch.onnx.export(
    model,
    example_input,
    'resnet18_cifar10.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

# 推理示例
loaded_model = torch.jit.load('resnet18_cifar10.pt')
with torch.no_grad():
    output = loaded_model(torch.randn(1, 3, 224, 224).to(device))

4. 常见问题与注意事项

  • 设备兼容性:加载模型时确保设备一致,GPU 训练的模型可能需 map_location 转换为 CPU。
  • 推理模式:使用 model.eval()torch.no_grad() 禁用训练行为和梯度计算。
  • 批大小:ONNX 和 TorchScript 支持动态批大小,但需正确配置 dynamic_axes
  • 模型压缩
  • 使用量化(torch.quantization)或剪枝(torch.nn.utils.prune)减小模型大小。
  • 检查推理环境是否支持 FP16 或 INT8。
  • 版本兼容性:确保 PyTorch 和 torchvision 版本与部署环境一致。
  • 安全性:避免直接加载不受信任的 .pth 文件,防止潜在的代码注入风险。

5. 参考资源

  • 官方文档
  • 模型保存/加载:https://pytorch.org/docs/stable/notes/serialization.html
  • TorchScript:https://pytorch.org/docs/stable/jit.html
  • ONNX:https://pytorch.org/docs/stable/onnx.html
  • TorchServe:https://pytorch.org/serve/
  • 教程:https://pytorch.org/tutorials/advanced/cpp_export.html
  • GitHub 仓库:https://github.com/pytorch/serve
  • 社区论坛:https://discuss.pytorch.org/

6. 进一步帮助

如果你需要针对特定场景的部署方案(如移动端、嵌入式设备、云服务),或需要优化推理性能(量化、TensorRT 集成)、调试部署问题,请提供更多细节,我可以为你提供定制化的代码或建议!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注