PyTorch 模型保存和加载
在 PyTorch 中,模型的保存和加载是模型训练、推理和部署的核心步骤,通常通过 torch.save
和 torch.load
实现。以下是关于 PyTorch 模型保存和加载的详细指南,涵盖方法、注意事项和示例,力求简洁且实用。
1. 保存和加载模型的两种主要方式
PyTorch 提供了两种常用的保存和加载模型的方法:
- 保存/加载
state_dict
(推荐):仅保存模型的参数(权重和偏置),更轻量、灵活。 - 保存/加载整个模型:保存模型结构和参数,文件较大,兼容性稍差。
2. 保存和加载 state_dict
state_dict
是一个 Python 字典,存储模型的权重和偏置,适用于大多数场景。
(1) 保存 state_dict
import torch
import torch.nn as nn
# 示例模型
model = nn.Linear(10, 2)
# 保存模型参数
torch.save(model.state_dict(), 'model_weights.pth')
(2) 加载 state_dict
加载时需先创建相同结构的模型实例,然后加载参数。
# 创建相同结构的模型
model = nn.Linear(10, 2)
# 加载参数
model.load_state_dict(torch.load('model_weights.pth'))
# 切换到推理模式
model.eval()
(3) 注意事项
- 模型结构一致:加载
state_dict
前,必须确保模型结构与保存时完全相同。 - 推理模式:推理时使用
model.eval()
禁用 Dropout 和 BatchNorm 的训练行为。 - 设备兼容性:如果模型在 GPU 上训练,加载到 CPU 环境需指定
map_location
:
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
3. 保存和加载整个模型
保存整个模型包括模型结构和参数,但文件较大,加载时可能需要特定的 PyTorch 版本。
(1) 保存整个模型
# 保存整个模型
torch.save(model, 'model.pth')
(2) 加载整个模型
# 加载整个模型
model = torch.load('model.pth')
model.eval()
(3) 注意事项
- 兼容性问题:保存整个模型可能因 PyTorch 版本或自定义层定义而导致加载失败。
- 推荐场景:仅在开发环境快速测试时使用,生产环境建议保存
state_dict
。
4. 保存和加载优化器状态
优化器状态(如动量、学习率)可以通过 optimizer.state_dict()
保存,适用于断点续训。
import torch.optim as optim
# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 保存模型和优化器状态
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}
torch.save(checkpoint, 'checkpoint.pth')
# 加载模型和优化器状态
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
5. 完整示例:训练、保存和加载
以下是一个完整的图像分类示例,展示如何保存和加载模型。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 超参数
num_classes = 10
batch_size = 32
epochs = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 数据准备
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 定义模型
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.flatten = nn.Flatten()
self.fc1 = nn.Linear(28 * 28, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, num_classes)
def forward(self, x):
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练
model.train()
for epoch in range(epochs):
total_loss = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}')
# 保存模型
torch.save(model.state_dict(), 'mnist_model.pth')
# 加载模型进行推理
model = SimpleNet().to(device)
model.load_state_dict(torch.load('mnist_model.pth', map_location=device))
model.eval()
# 推理示例
with torch.no_grad():
test_image = torch.randn(1, 1, 28, 28).to(device)
output = model(test_image)
prediction = output.argmax(dim=1)
print(f'Predicted class: {prediction.item()}')
6. 进阶用法
- 保存最佳模型:
在训练过程中保存验证集上表现最好的模型:
best_loss = float('inf')
for epoch in range(epochs):
# 训练和验证逻辑
val_loss = validate(model, val_loader, criterion)
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), 'best_model.pth')
- TorchScript 保存:
为部署到非 Python 环境(如 C++),保存为 TorchScript 格式:
model.eval()
traced_model = torch.jit.trace(model, torch.randn(1, 1, 28, 28).to(device))
traced_model.save('mnist_model.pt')
- ONNX 导出:
为跨平台部署,导出为 ONNX 格式:
model.eval()
dummy_input = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(model, dummy_input, 'mnist_model.onnx', input_names=['input'], output_names=['output'])
- 断点续训:
保存和加载模型、优化器及训练状态:
# 保存
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, 'checkpoint.pth')
# 加载
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
7. 常见问题与注意事项
- 设备一致性:
- 如果模型在 GPU 上训练,加载到 CPU 需使用
map_location='cpu'
。 - 跨设备加载可能导致错误,检查
model.to(device)
和data.to(device)
。 - 推理模式:
- 推理时必须调用
model.eval()
和torch.no_grad()
,以禁用 Dropout/BatchNorm 和节省内存。 - 版本兼容性:
- 不同 PyTorch 版本可能导致
state_dict
或模型加载失败,建议保存时记录版本号。 - 使用
torch.jit
或 ONNX 提高跨环境兼容性。 - 文件安全性:
- 加载
.pth
文件可能存在代码注入风险,仅从可信来源加载。 - 存储优化:
state_dict
比整个模型占用空间小,适合生产环境。- 使用
torch.save(..., _use_new_zipfile_serialization=True)
压缩文件。
8. 参考资源
- 官方文档:
- 保存和加载:https://pytorch.org/docs/stable/notes/serialization.html
- TorchScript:https://pytorch.org/docs/stable/jit.html
- ONNX:https://pytorch.org/docs/stable/onnx.html
- 教程:https://pytorch.org/tutorials/beginner/saving_loading_models.html
- 社区论坛:https://discuss.pytorch.org/
9. 进一步帮助
如果你需要特定场景的保存/加载方案(例如分布式训练、跨平台部署、或优化存储),或遇到加载错误、性能问题,请提供更多细节,我可以为你提供定制化的代码或解决方案!