PyTorch 实例 – 图像分类项目
以下是一个完整的 PyTorch 图像分类项目示例,使用 torchvision
的 CIFAR-10 数据集和 ResNet18 模型进行训练、验证和测试。代码包含数据预处理、模型定义、训练循环、验证和推理,适合初学者和需要快速上手的开发者。示例代码清晰、注释详细,并包括保存/加载模型的步骤。
项目概述
- 任务:对 CIFAR-10 数据集进行图像分类(10 类:飞机、汽车、鸟等)。
- 数据集:CIFAR-10(32×32 彩色图像,50,000 训练样本,10,000 测试样本)。
- 模型:ResNet18(支持预训练或从头训练)。
- 功能:数据加载、预处理、模型训练、验证、测试、模型保存和推理。
- 环境:PyTorch、torchvision,运行在 CPU 或 GPU 上。
完整代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms, models
import numpy as np
# 设置随机种子以确保可重复性
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed(42)
# 超参数
num_classes = 10
batch_size = 64
epochs = 10
learning_rate = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
# 1. 数据预处理
train_transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整为 ResNet 输入尺寸
transforms.RandomHorizontalFlip(), # 数据增强
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 标准化
])
test_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 2. 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
# 划分训练集和验证集
num_train = len(train_dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(0.8 * num_train) # 80% 训练,20% 验证
train_idx, val_idx = indices[:split], indices[split:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
val_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=val_sampler, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# 3. 定义模型
model = models.resnet18(weights=None) # 从头训练(可改为 weights='DEFAULT' 使用预训练)
model.fc = nn.Linear(model.fc.in_features, num_classes) # 修改分类层
model = model.to(device)
# 4. 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) # 每 5 个 epoch 降低学习率
# 5. 训练和验证函数
def train(model, loader, criterion, optimizer, device):
model.train()
total_loss = 0
correct = 0
total = 0
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return total_loss / len(loader), 100 * correct / total
def validate(model, loader, criterion, device):
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
total_loss += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return total_loss / len(loader), 100 * correct / total
# 6. 训练循环
best_val_acc = 0
for epoch in range(epochs):
train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
val_loss, val_acc = validate(model, val_loader, criterion, device)
scheduler.step()
print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
# 保存最佳模型
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), 'best_model.pth')
print(f'Saved best model with Val Acc: {val_acc:.2f}%')
# 7. 测试模型
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
test_acc = 100 * correct / total
print(f'Test Accuracy: {test_acc:.2f}%')
# 8. 推理示例
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
model.eval()
with torch.no_grad():
sample_image, sample_label = next(iter(test_loader))
sample_image = sample_image[0:1].to(device) # 取一个样本
output = model(sample_image)
_, predicted = torch.max(output, 1)
print(f'Predicted: {class_names[predicted.item()]}, Actual: {class_names[sample_label[0]]}')
代码说明
- 数据预处理:
- 使用
transforms
进行数据增强(训练集:随机翻转、旋转)和标准化(ImageNet 均值/标准差)。 - CIFAR-10 数据集自动下载,训练集划分为 80% 训练和 20% 验证。
- 模型:
- 使用 ResNet18,修改全连接层以适配 10 类输出。
- 支持从头训练(
weights=None
)或使用预训练权重(weights='DEFAULT'
)。
- 训练与验证:
- 训练循环计算损失和准确率,验证集用于监控模型性能。
- 使用 Adam 优化器和 StepLR 调度器动态调整学习率。
- 保存验证集上表现最佳的模型。
- 测试与推理:
- 加载最佳模型进行测试,计算整体准确率。
- 展示单张图像的预测结果。
- 运行环境:
- 支持 CPU 和 GPU(自动检测)。
- 使用
num_workers=2
加速数据加载(Windows 用户可能需设为 0)。
输出示例
运行代码可能得到类似以下输出(具体数值因随机性和硬件不同而异):
Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10, Train Loss: 1.2345, Train Acc: 55.32%, Val Loss: 1.0123, Val Acc: 65.43%
Saved best model with Val Acc: 65.43%
Epoch 2/10, Train Loss: 0.9876, Train Acc: 67.89%, Val Loss: 0.8765, Val Acc: 70.12%
Saved best model with Val Acc: 70.12%
...
Test Accuracy: 71.23%
Predicted: cat, Actual: cat
运行要求
- 依赖:
torch
,torchvision
,numpy
,matplotlib
(可选,用于可视化)。
pip install torch torchvision numpy
- 硬件:支持 CPU 或 GPU,GPU 加速需要 CUDA 兼容的 NVIDIA 显卡。
- 数据集:CIFAR-10 会自动下载到
./data
目录(约 170MB)。
进阶建议
- 预训练模型:
- 将
model = models.resnet18(weights=None)
改为weights='DEFAULT'
,利用 ImageNet 预训练权重加速收敛。 - 冻结卷积层,仅训练全连接层:
python for param in model.parameters(): param.requires_grad = False model.fc = nn.Linear(model.fc.in_features, num_classes)
- 数据增强:
- 添加更多增强(如
transforms.ColorJitter()
或transforms.RandomCrop()
)。 - 使用
albumentations
库实现复杂增强:python import albumentations as A from albumentations.pytorch import ToTensorV2 transform = A.Compose([ A.Resize(224, 224), A.HorizontalFlip(p=0.5), ToTensorV2() ])
- 优化性能:
- 使用混合精度训练:
python from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() with autocast(): outputs = model(images) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
- 增加
num_workers
或使用pin_memory=True
加速数据加载。
- 模型部署:
- 导出为 TorchScript:
python model.eval() traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to(device)) traced_model.save('resnet18_cifar10.pt')
- 导出为 ONNX:
python torch.onnx.export(model, torch.randn(1, 3, 224, 224).to(device), 'resnet18_cifar10.onnx')
- 可视化:
- 使用
torchvision.utils.make_grid
可视化预测结果:python from torchvision.utils import make_grid import matplotlib.pyplot as plt images, labels = next(iter(test_loader)) grid = make_grid(images[:8]) plt.imshow(grid.permute(1, 2, 0).numpy()) plt.show()
常见问题与注意事项
- 内存不足:降低
batch_size
或使用torch.cuda.amp
进行混合精度训练。 - 数据加载慢:在 Linux/Mac 上增加
num_workers
,Windows 用户可能需设为 0。 - 过拟合:增加数据增强、Dropout 或正则化(如
weight_decay=1e-4
)。 - 预训练模型:确保输入图像经过 ImageNet 标准化(均值/标准差)。
- 版本兼容性:使用
torchvision>=0.13
支持新权重 API(如ResNet18_Weights
)。
参考资源
- 官方文档:
torchvision.datasets
:https://pytorch.org/vision/stable/datasets.htmltorchvision.models
:https://pytorch.org/vision/stable/models.htmltorchvision.transforms
:https://pytorch.org/vision/stable/transforms.html- 教程:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
- 社区论坛:https://discuss.pytorch.org/
进一步帮助
如果你需要扩展此项目(如添加目标检测、分割任务、或部署到生产环境),优化性能(量化、TensorRT),或调试特定问题,请提供更多细节,我可以为你提供定制化的代码或建议!