PyTorch torchvision 计算机视觉模块
torchvision
是 PyTorch 的一个子库,专门为计算机视觉任务提供工具,包括数据集、模型、图像变换和实用函数。以下是关于 torchvision
模块的参考手册,涵盖核心功能、常用组件和示例,力求简洁且全面,适合快速上手或深入了解。
1. torchvision
模块概述
torchvision
提供了以下主要功能:
- 数据集(
torchvision.datasets
):内置常见计算机视觉数据集(如 MNIST、CIFAR-10、ImageNet)。 - 模型(
torchvision.models
):预训练和未训练的深度学习模型(如 ResNet、VGG、Vision Transformer)。 - 图像变换(
torchvision.transforms
):数据预处理和增强工具。 - 实用工具:如图像加载、显示和格式转换。
torchvision
与 torch
和 torch.nn
无缝集成,广泛用于图像分类、目标检测、分割等任务。
2. 核心组件与使用方法
(1) 数据集(torchvision.datasets
)
torchvision.datasets
提供多个标准数据集,支持下载、预处理和与 DataLoader
配合使用。
- 常用数据集:
MNIST
:手写数字(28×28 灰度图像,10 类)。CIFAR10
/CIFAR100
:小尺寸彩色图像(32×32,10/100 类)。ImageNet
:大规模图像数据集(需自定义路径)。VOCSegmentation
/CocoDetection
:用于分割和目标检测。FashionMNIST
:时尚物品分类,类似 MNIST。- 使用示例:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义变换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载 MNIST
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
- 自定义数据集:
使用datasets.ImageFolder
加载本地图像文件夹,文件夹结构需为root/class_name/image.jpg
。
dataset = datasets.ImageFolder(root='path/to/images', transform=transform)
(2) 模型(torchvision.models
)
torchvision.models
提供预训练和未训练的模型,涵盖图像分类、检测、分割等任务。
- 分类模型:
resnet18
,resnet50
:残差网络。vgg16
,vgg19
:VGG 网络。vit_b_16
:Vision Transformer。mobilenet_v3_small
:轻量级移动端模型。- 预训练模型:
from torchvision.models import resnet18, ResNet18_Weights
# 加载预训练 ResNet18
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.eval()
- 自定义输出层:
修改最后一层以适应特定任务(如更改分类数)。
model.fc = nn.Linear(model.fc.in_features, num_classes) # 修改全连接层
- 检测与分割模型:
fasterrcnn_resnet50_fpn
:目标检测。maskrcnn_resnet50_fpn
:实例分割。
from torchvision.models.detection import fasterrcnn_resnet50_fpn
detector = fasterrcnn_resnet50_fpn(weights='DEFAULT')
(3) 图像变换(torchvision.transforms
)
transforms
提供图像预处理和数据增强功能,适用于 PIL.Image
、张量或 NumPy 数组。
- 常用变换:
ToTensor()
:将 PIL/NumPy 转为张量([C, H, W]
,[0, 1]
)。Normalize(mean, std)
:标准化(如 ImageNet 的[0.485, 0.456, 0.406]
,[0.229, 0.224, 0.225]
)。Resize(size)
:调整图像大小。RandomHorizontalFlip(p=0.5)
:随机水平翻转。RandomResizedCrop(size)
:随机裁剪并调整大小。ColorJitter(brightness, contrast, saturation, hue)
:颜色调整。- 组合变换:
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
- 函数式变换(
torchvision.transforms.functional
):
提供更细粒度的控制,如手动调整亮度:
from torchvision.transforms import functional as F
image = F.adjust_brightness(image, brightness_factor=1.2)
(4) 实用工具
- 图像加载与保存:
from torchvision.io import read_image, write_image
image = read_image('image.jpg') # 加载为张量
write_image(image, 'output.jpg') # 保存张量为图像
- 可视化:
使用torchvision.utils.make_grid
拼接图像网格:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
images, _ = next(iter(train_loader))
grid = make_grid(images)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.show()
3. 完整示例:图像分类
以下是一个使用 torchvision
进行图像分类的完整示例:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# 超参数
num_classes = 10
batch_size = 32
epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 数据变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
# 加载预训练模型
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Linear(model.fc.in_features, num_classes) # 修改分类层
model = model.to(device)
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练
model.train()
for epoch in range(epochs):
total_loss = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}')
# 测试
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100 * correct / total:.2f}%')
4. 进阶用法
- 迁移学习:
- 冻结预训练层的参数,仅训练新添加的层:
python for param in model.parameters(): param.requires_grad = False model.fc = nn.Linear(model.fc.in_features, num_classes)
- 自定义数据集:
继承torch.utils.data.Dataset
处理非标准数据集:
class CustomImageDataset(Dataset):
def __init__(self, img_dir, transform=None):
self.img_paths = [os.path.join(img_dir, f) for f in os.listdir(img_dir)]
self.transform = transform
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
image = Image.open(self.img_paths[idx])
label = ... # 提取标签逻辑
if self.transform:
image = self.transform(image)
return image, label
- 目标检测与分割:
使用torchvision.models.detection
的模型(如 Faster R-CNN):
from torchvision.models.detection import fasterrcnn_resnet50_fpn
model = fasterrcnn_resnet50_fpn(weights='DEFAULT').to(device)
model.eval()
predictions = model([image_tensor])
- 混合精度训练:
使用torch.cuda.amp
加速训练:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
outputs = model(images)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
5. 常见问题与注意事项
- 数据格式:确保图像张量为
[batch_size, channels, height, width]
,范围[0, 1]
(ToTensor
后)。 - 预训练权重:使用
weights='DEFAULT'
或具体版本(如ResNet18_Weights.IMAGENET1K_V1
)加载最新或指定权重。 - 内存管理:大数据集使用
num_workers
和pin_memory
加速DataLoader
:
DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)
- 归一化:预训练模型通常需要 ImageNet 的均值和标准差进行归一化。
- 版本兼容性:
torchvision
随 PyTorch 更新,检查版本兼容性(如torchvision>=0.13
支持新权重 API)。
6. 参考资源
- 官方文档:
torchvision
:https://pytorch.org/vision/stable/index.htmldatasets
:https://pytorch.org/vision/stable/datasets.htmlmodels
:https://pytorch.org/vision/stable/models.htmltransforms
:https://pytorch.org/vision/stable/transforms.html- 教程:https://pytorch.org/tutorials/beginner/basics/vision_tutorial.html
- GitHub 仓库:https://github.com/pytorch/vision
- 社区论坛:https://discuss.pytorch.org/
7. 进一步帮助
如果你需要特定任务的实现(如目标检测、图像分割)、优化 torchvision
使用(如数据增强策略、模型微调),或调试代码,请提供更多细节,我可以为你提供定制化的代码或建议!