PyTorch torchvision 计算机视觉模块

torchvision 是 PyTorch 的一个子库,专门为计算机视觉任务提供工具,包括数据集、模型、图像变换和实用函数。以下是关于 torchvision 模块的参考手册,涵盖核心功能、常用组件和示例,力求简洁且全面,适合快速上手或深入了解。


1. torchvision 模块概述

torchvision 提供了以下主要功能:

  • 数据集(torchvision.datasets:内置常见计算机视觉数据集(如 MNIST、CIFAR-10、ImageNet)。
  • 模型(torchvision.models:预训练和未训练的深度学习模型(如 ResNet、VGG、Vision Transformer)。
  • 图像变换(torchvision.transforms:数据预处理和增强工具。
  • 实用工具:如图像加载、显示和格式转换。

torchvisiontorchtorch.nn 无缝集成,广泛用于图像分类、目标检测、分割等任务。


2. 核心组件与使用方法

(1) 数据集(torchvision.datasets

torchvision.datasets 提供多个标准数据集,支持下载、预处理和与 DataLoader 配合使用。

  • 常用数据集
  • MNIST:手写数字(28×28 灰度图像,10 类)。
  • CIFAR10 / CIFAR100:小尺寸彩色图像(32×32,10/100 类)。
  • ImageNet:大规模图像数据集(需自定义路径)。
  • VOCSegmentation / CocoDetection:用于分割和目标检测。
  • FashionMNIST:时尚物品分类,类似 MNIST。
  • 使用示例
  from torchvision import datasets, transforms
  from torch.utils.data import DataLoader

  # 定义变换
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5,), (0.5,))
  ])

  # 加载 MNIST
  train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

  # 创建 DataLoader
  train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
  • 自定义数据集
    使用 datasets.ImageFolder 加载本地图像文件夹,文件夹结构需为 root/class_name/image.jpg
  dataset = datasets.ImageFolder(root='path/to/images', transform=transform)

(2) 模型(torchvision.models

torchvision.models 提供预训练和未训练的模型,涵盖图像分类、检测、分割等任务。

  • 分类模型
  • resnet18, resnet50:残差网络。
  • vgg16, vgg19:VGG 网络。
  • vit_b_16:Vision Transformer。
  • mobilenet_v3_small:轻量级移动端模型。
  • 预训练模型
  from torchvision.models import resnet18, ResNet18_Weights

  # 加载预训练 ResNet18
  model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
  model.eval()
  • 自定义输出层
    修改最后一层以适应特定任务(如更改分类数)。
  model.fc = nn.Linear(model.fc.in_features, num_classes)  # 修改全连接层
  • 检测与分割模型
  • fasterrcnn_resnet50_fpn:目标检测。
  • maskrcnn_resnet50_fpn:实例分割。
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
  detector = fasterrcnn_resnet50_fpn(weights='DEFAULT')

(3) 图像变换(torchvision.transforms

transforms 提供图像预处理和数据增强功能,适用于 PIL.Image、张量或 NumPy 数组。

  • 常用变换
  • ToTensor():将 PIL/NumPy 转为张量([C, H, W], [0, 1])。
  • Normalize(mean, std):标准化(如 ImageNet 的 [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])。
  • Resize(size):调整图像大小。
  • RandomHorizontalFlip(p=0.5):随机水平翻转。
  • RandomResizedCrop(size):随机裁剪并调整大小。
  • ColorJitter(brightness, contrast, saturation, hue):颜色调整。
  • 组合变换
  transform = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
  • 函数式变换(torchvision.transforms.functional
    提供更细粒度的控制,如手动调整亮度:
  from torchvision.transforms import functional as F
  image = F.adjust_brightness(image, brightness_factor=1.2)

(4) 实用工具

  • 图像加载与保存
  from torchvision.io import read_image, write_image
  image = read_image('image.jpg')  # 加载为张量
  write_image(image, 'output.jpg')  # 保存张量为图像
  • 可视化
    使用 torchvision.utils.make_grid 拼接图像网格:
  from torchvision.utils import make_grid
  import matplotlib.pyplot as plt
  images, _ = next(iter(train_loader))
  grid = make_grid(images)
  plt.imshow(grid.permute(1, 2, 0).numpy())
  plt.show()

3. 完整示例:图像分类

以下是一个使用 torchvision 进行图像分类的完整示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# 超参数
num_classes = 10
batch_size = 32
epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 数据变换
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# 加载预训练模型
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Linear(model.fc.in_features, num_classes)  # 修改分类层
model = model.to(device)

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练
model.train()
for epoch in range(epochs):
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}')

# 测试
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100 * correct / total:.2f}%')

4. 进阶用法

  • 迁移学习
  • 冻结预训练层的参数,仅训练新添加的层:
    python for param in model.parameters(): param.requires_grad = False model.fc = nn.Linear(model.fc.in_features, num_classes)
  • 自定义数据集
    继承 torch.utils.data.Dataset 处理非标准数据集:
  class CustomImageDataset(Dataset):
      def __init__(self, img_dir, transform=None):
          self.img_paths = [os.path.join(img_dir, f) for f in os.listdir(img_dir)]
          self.transform = transform

      def __len__(self):
          return len(self.img_paths)

      def __getitem__(self, idx):
          image = Image.open(self.img_paths[idx])
          label = ...  # 提取标签逻辑
          if self.transform:
              image = self.transform(image)
          return image, label
  • 目标检测与分割
    使用 torchvision.models.detection 的模型(如 Faster R-CNN):
  from torchvision.models.detection import fasterrcnn_resnet50_fpn
  model = fasterrcnn_resnet50_fpn(weights='DEFAULT').to(device)
  model.eval()
  predictions = model([image_tensor])
  • 混合精度训练
    使用 torch.cuda.amp 加速训练:
  from torch.cuda.amp import autocast, GradScaler
  scaler = GradScaler()
  with autocast():
      outputs = model(images)
      loss = criterion(outputs, labels)
  scaler.scale(loss).backward()
  scaler.step(optimizer)
  scaler.update()

5. 常见问题与注意事项

  • 数据格式:确保图像张量为 [batch_size, channels, height, width],范围 [0, 1]ToTensor 后)。
  • 预训练权重:使用 weights='DEFAULT' 或具体版本(如 ResNet18_Weights.IMAGENET1K_V1)加载最新或指定权重。
  • 内存管理:大数据集使用 num_workerspin_memory 加速 DataLoader
  DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)
  • 归一化:预训练模型通常需要 ImageNet 的均值和标准差进行归一化。
  • 版本兼容性torchvision 随 PyTorch 更新,检查版本兼容性(如 torchvision>=0.13 支持新权重 API)。

6. 参考资源

  • 官方文档
  • torchvision:https://pytorch.org/vision/stable/index.html
  • datasets:https://pytorch.org/vision/stable/datasets.html
  • models:https://pytorch.org/vision/stable/models.html
  • transforms:https://pytorch.org/vision/stable/transforms.html
  • 教程:https://pytorch.org/tutorials/beginner/basics/vision_tutorial.html
  • GitHub 仓库:https://github.com/pytorch/vision
  • 社区论坛:https://discuss.pytorch.org/

7. 进一步帮助

如果你需要特定任务的实现(如目标检测、图像分割)、优化 torchvision 使用(如数据增强策略、模型微调),或调试代码,请提供更多细节,我可以为你提供定制化的代码或建议!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注