PyTorch 模型保存和加载

在 PyTorch 中,模型的保存和加载是模型训练、推理和部署的核心步骤,通常通过 torch.savetorch.load 实现。以下是关于 PyTorch 模型保存和加载的详细指南,涵盖方法、注意事项和示例,力求简洁且实用。


1. 保存和加载模型的两种主要方式

PyTorch 提供了两种常用的保存和加载模型的方法:

  • 保存/加载 state_dict(推荐):仅保存模型的参数(权重和偏置),更轻量、灵活。
  • 保存/加载整个模型:保存模型结构和参数,文件较大,兼容性稍差。

2. 保存和加载 state_dict

state_dict 是一个 Python 字典,存储模型的权重和偏置,适用于大多数场景。

(1) 保存 state_dict

import torch
import torch.nn as nn

# 示例模型
model = nn.Linear(10, 2)

# 保存模型参数
torch.save(model.state_dict(), 'model_weights.pth')

(2) 加载 state_dict

加载时需先创建相同结构的模型实例,然后加载参数。

# 创建相同结构的模型
model = nn.Linear(10, 2)

# 加载参数
model.load_state_dict(torch.load('model_weights.pth'))

# 切换到推理模式
model.eval()

(3) 注意事项

  • 模型结构一致:加载 state_dict 前,必须确保模型结构与保存时完全相同。
  • 推理模式:推理时使用 model.eval() 禁用 Dropout 和 BatchNorm 的训练行为。
  • 设备兼容性:如果模型在 GPU 上训练,加载到 CPU 环境需指定 map_location
  model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))

3. 保存和加载整个模型

保存整个模型包括模型结构和参数,但文件较大,加载时可能需要特定的 PyTorch 版本。

(1) 保存整个模型

# 保存整个模型
torch.save(model, 'model.pth')

(2) 加载整个模型

# 加载整个模型
model = torch.load('model.pth')
model.eval()

(3) 注意事项

  • 兼容性问题:保存整个模型可能因 PyTorch 版本或自定义层定义而导致加载失败。
  • 推荐场景:仅在开发环境快速测试时使用,生产环境建议保存 state_dict

4. 保存和加载优化器状态

优化器状态(如动量、学习率)可以通过 optimizer.state_dict() 保存,适用于断点续训。

import torch.optim as optim

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 保存模型和优化器状态
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': epoch
}
torch.save(checkpoint, 'checkpoint.pth')

# 加载模型和优化器状态
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

5. 完整示例:训练、保存和加载

以下是一个完整的图像分类示例,展示如何保存和加载模型。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 超参数
num_classes = 10
batch_size = 32
epochs = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 数据准备
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# 定义模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练
model.train()
for epoch in range(epochs):
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}')

# 保存模型
torch.save(model.state_dict(), 'mnist_model.pth')

# 加载模型进行推理
model = SimpleNet().to(device)
model.load_state_dict(torch.load('mnist_model.pth', map_location=device))
model.eval()

# 推理示例
with torch.no_grad():
    test_image = torch.randn(1, 1, 28, 28).to(device)
    output = model(test_image)
    prediction = output.argmax(dim=1)
    print(f'Predicted class: {prediction.item()}')

6. 进阶用法

  • 保存最佳模型
    在训练过程中保存验证集上表现最好的模型:
  best_loss = float('inf')
  for epoch in range(epochs):
      # 训练和验证逻辑
      val_loss = validate(model, val_loader, criterion)
      if val_loss < best_loss:
          best_loss = val_loss
          torch.save(model.state_dict(), 'best_model.pth')
  • TorchScript 保存
    为部署到非 Python 环境(如 C++),保存为 TorchScript 格式:
  model.eval()
  traced_model = torch.jit.trace(model, torch.randn(1, 1, 28, 28).to(device))
  traced_model.save('mnist_model.pt')
  • ONNX 导出
    为跨平台部署,导出为 ONNX 格式:
  model.eval()
  dummy_input = torch.randn(1, 1, 28, 28).to(device)
  torch.onnx.export(model, dummy_input, 'mnist_model.onnx', input_names=['input'], output_names=['output'])
  • 断点续训
    保存和加载模型、优化器及训练状态:
  # 保存
  torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss
  }, 'checkpoint.pth')

  # 加载
  checkpoint = torch.load('checkpoint.pth')
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  start_epoch = checkpoint['epoch']

7. 常见问题与注意事项

  • 设备一致性
  • 如果模型在 GPU 上训练,加载到 CPU 需使用 map_location='cpu'
  • 跨设备加载可能导致错误,检查 model.to(device)data.to(device)
  • 推理模式
  • 推理时必须调用 model.eval()torch.no_grad(),以禁用 Dropout/BatchNorm 和节省内存。
  • 版本兼容性
  • 不同 PyTorch 版本可能导致 state_dict 或模型加载失败,建议保存时记录版本号。
  • 使用 torch.jit 或 ONNX 提高跨环境兼容性。
  • 文件安全性
  • 加载 .pth 文件可能存在代码注入风险,仅从可信来源加载。
  • 存储优化
  • state_dict 比整个模型占用空间小,适合生产环境。
  • 使用 torch.save(..., _use_new_zipfile_serialization=True) 压缩文件。

8. 参考资源

  • 官方文档
  • 保存和加载:https://pytorch.org/docs/stable/notes/serialization.html
  • TorchScript:https://pytorch.org/docs/stable/jit.html
  • ONNX:https://pytorch.org/docs/stable/onnx.html
  • 教程:https://pytorch.org/tutorials/beginner/saving_loading_models.html
  • 社区论坛:https://discuss.pytorch.org/

9. 进一步帮助

如果你需要特定场景的保存/加载方案(例如分布式训练、跨平台部署、或优化存储),或遇到加载错误、性能问题,请提供更多细节,我可以为你提供定制化的代码或解决方案!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注