PyTorch 实例 – 图像分类项目

以下是一个完整的 PyTorch 图像分类项目示例,使用 torchvision 的 CIFAR-10 数据集和 ResNet18 模型进行训练、验证和测试。代码包含数据预处理、模型定义、训练循环、验证和推理,适合初学者和需要快速上手的开发者。示例代码清晰、注释详细,并包括保存/加载模型的步骤。


项目概述

  • 任务:对 CIFAR-10 数据集进行图像分类(10 类:飞机、汽车、鸟等)。
  • 数据集:CIFAR-10(32×32 彩色图像,50,000 训练样本,10,000 测试样本)。
  • 模型:ResNet18(支持预训练或从头训练)。
  • 功能:数据加载、预处理、模型训练、验证、测试、模型保存和推理。
  • 环境:PyTorch、torchvision,运行在 CPU 或 GPU 上。

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms, models
import numpy as np

# 设置随机种子以确保可重复性
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# 超参数
num_classes = 10
batch_size = 64
epochs = 10
learning_rate = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# 1. 数据预处理
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整为 ResNet 输入尺寸
    transforms.RandomHorizontalFlip(),  # 数据增强
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet 标准化
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 2. 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# 划分训练集和验证集
num_train = len(train_dataset)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(0.8 * num_train)  # 80% 训练,20% 验证
train_idx, val_idx = indices[:split], indices[split:]

train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
val_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=val_sampler, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# 3. 定义模型
model = models.resnet18(weights=None)  # 从头训练(可改为 weights='DEFAULT' 使用预训练)
model.fc = nn.Linear(model.fc.in_features, num_classes)  # 修改分类层
model = model.to(device)

# 4. 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # 每 5 个 epoch 降低学习率

# 5. 训练和验证函数
def train(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return total_loss / len(loader), 100 * correct / total

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return total_loss / len(loader), 100 * correct / total

# 6. 训练循环
best_val_acc = 0
for epoch in range(epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)
    scheduler.step()

    print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

    # 保存最佳模型
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'Saved best model with Val Acc: {val_acc:.2f}%')

# 7. 测试模型
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
test_acc = 100 * correct / total
print(f'Test Accuracy: {test_acc:.2f}%')

# 8. 推理示例
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']
model.eval()
with torch.no_grad():
    sample_image, sample_label = next(iter(test_loader))
    sample_image = sample_image[0:1].to(device)  # 取一个样本
    output = model(sample_image)
    _, predicted = torch.max(output, 1)
    print(f'Predicted: {class_names[predicted.item()]}, Actual: {class_names[sample_label[0]]}')

代码说明

  1. 数据预处理
  • 使用 transforms 进行数据增强(训练集:随机翻转、旋转)和标准化(ImageNet 均值/标准差)。
  • CIFAR-10 数据集自动下载,训练集划分为 80% 训练和 20% 验证。
  1. 模型
  • 使用 ResNet18,修改全连接层以适配 10 类输出。
  • 支持从头训练(weights=None)或使用预训练权重(weights='DEFAULT')。
  1. 训练与验证
  • 训练循环计算损失和准确率,验证集用于监控模型性能。
  • 使用 Adam 优化器和 StepLR 调度器动态调整学习率。
  • 保存验证集上表现最佳的模型。
  1. 测试与推理
  • 加载最佳模型进行测试,计算整体准确率。
  • 展示单张图像的预测结果。
  1. 运行环境
  • 支持 CPU 和 GPU(自动检测)。
  • 使用 num_workers=2 加速数据加载(Windows 用户可能需设为 0)。

输出示例

运行代码可能得到类似以下输出(具体数值因随机性和硬件不同而异):

Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10, Train Loss: 1.2345, Train Acc: 55.32%, Val Loss: 1.0123, Val Acc: 65.43%
Saved best model with Val Acc: 65.43%
Epoch 2/10, Train Loss: 0.9876, Train Acc: 67.89%, Val Loss: 0.8765, Val Acc: 70.12%
Saved best model with Val Acc: 70.12%
...
Test Accuracy: 71.23%
Predicted: cat, Actual: cat

运行要求

  • 依赖torch, torchvision, numpy, matplotlib(可选,用于可视化)。
  pip install torch torchvision numpy
  • 硬件:支持 CPU 或 GPU,GPU 加速需要 CUDA 兼容的 NVIDIA 显卡。
  • 数据集:CIFAR-10 会自动下载到 ./data 目录(约 170MB)。

进阶建议

  1. 预训练模型
  • model = models.resnet18(weights=None) 改为 weights='DEFAULT',利用 ImageNet 预训练权重加速收敛。
  • 冻结卷积层,仅训练全连接层:
    python for param in model.parameters(): param.requires_grad = False model.fc = nn.Linear(model.fc.in_features, num_classes)
  1. 数据增强
  • 添加更多增强(如 transforms.ColorJitter()transforms.RandomCrop())。
  • 使用 albumentations 库实现复杂增强:
    python import albumentations as A from albumentations.pytorch import ToTensorV2 transform = A.Compose([ A.Resize(224, 224), A.HorizontalFlip(p=0.5), ToTensorV2() ])
  1. 优化性能
  • 使用混合精度训练:
    python from torch.cuda.amp import autocast, GradScaler scaler = GradScaler() with autocast(): outputs = model(images) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
  • 增加 num_workers 或使用 pin_memory=True 加速数据加载。
  1. 模型部署
  • 导出为 TorchScript:
    python model.eval() traced_model = torch.jit.trace(model, torch.randn(1, 3, 224, 224).to(device)) traced_model.save('resnet18_cifar10.pt')
  • 导出为 ONNX:
    python torch.onnx.export(model, torch.randn(1, 3, 224, 224).to(device), 'resnet18_cifar10.onnx')
  1. 可视化
  • 使用 torchvision.utils.make_grid 可视化预测结果:
    python from torchvision.utils import make_grid import matplotlib.pyplot as plt images, labels = next(iter(test_loader)) grid = make_grid(images[:8]) plt.imshow(grid.permute(1, 2, 0).numpy()) plt.show()

常见问题与注意事项

  • 内存不足:降低 batch_size 或使用 torch.cuda.amp 进行混合精度训练。
  • 数据加载慢:在 Linux/Mac 上增加 num_workers,Windows 用户可能需设为 0。
  • 过拟合:增加数据增强、Dropout 或正则化(如 weight_decay=1e-4)。
  • 预训练模型:确保输入图像经过 ImageNet 标准化(均值/标准差)。
  • 版本兼容性:使用 torchvision>=0.13 支持新权重 API(如 ResNet18_Weights)。

参考资源

  • 官方文档
  • torchvision.datasets:https://pytorch.org/vision/stable/datasets.html
  • torchvision.models:https://pytorch.org/vision/stable/models.html
  • torchvision.transforms:https://pytorch.org/vision/stable/transforms.html
  • 教程:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
  • 社区论坛:https://discuss.pytorch.org/

进一步帮助

如果你需要扩展此项目(如添加目标检测、分割任务、或部署到生产环境),优化性能(量化、TensorRT),或调试特定问题,请提供更多细节,我可以为你提供定制化的代码或建议!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注