torchvision 是 PyTorch 的一个子库,专门为计算机视觉任务提供工具,包括数据集、模型、图像变换和实用函数。以下是关于 torchvision 模块的参考手册,涵盖核心功能、常用组件和示例,力求简洁且全面,适合快速上手或深入了解。
1. torchvision 模块概述
torchvision 提供了以下主要功能:
- 数据集(
torchvision.datasets):内置常见计算机视觉数据集(如 MNIST、CIFAR-10、ImageNet)。 - 模型(
torchvision.models):预训练和未训练的深度学习模型(如 ResNet、VGG、Vision Transformer)。 - 图像变换(
torchvision.transforms):数据预处理和增强工具。 - 实用工具:如图像加载、显示和格式转换。
torchvision 与 torch 和 torch.nn 无缝集成,广泛用于图像分类、目标检测、分割等任务。
2. 核心组件与使用方法
(1) 数据集(torchvision.datasets)
torchvision.datasets 提供多个标准数据集,支持下载、预处理和与 DataLoader 配合使用。
- 常用数据集:
MNIST:手写数字(28×28 灰度图像,10 类)。CIFAR10/CIFAR100:小尺寸彩色图像(32×32,10/100 类)。ImageNet:大规模图像数据集(需自定义路径)。VOCSegmentation/CocoDetection:用于分割和目标检测。FashionMNIST:时尚物品分类,类似 MNIST。- 使用示例:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义变换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载 MNIST
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
- 自定义数据集:
使用datasets.ImageFolder加载本地图像文件夹,文件夹结构需为root/class_name/image.jpg。
dataset = datasets.ImageFolder(root='path/to/images', transform=transform)
(2) 模型(torchvision.models)
torchvision.models 提供预训练和未训练的模型,涵盖图像分类、检测、分割等任务。
- 分类模型:
resnet18,resnet50:残差网络。vgg16,vgg19:VGG 网络。vit_b_16:Vision Transformer。mobilenet_v3_small:轻量级移动端模型。- 预训练模型:
from torchvision.models import resnet18, ResNet18_Weights
# 加载预训练 ResNet18
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.eval()
- 自定义输出层:
修改最后一层以适应特定任务(如更改分类数)。
model.fc = nn.Linear(model.fc.in_features, num_classes) # 修改全连接层
- 检测与分割模型:
fasterrcnn_resnet50_fpn:目标检测。maskrcnn_resnet50_fpn:实例分割。
from torchvision.models.detection import fasterrcnn_resnet50_fpn
detector = fasterrcnn_resnet50_fpn(weights='DEFAULT')
(3) 图像变换(torchvision.transforms)
transforms 提供图像预处理和数据增强功能,适用于 PIL.Image、张量或 NumPy 数组。
- 常用变换:
ToTensor():将 PIL/NumPy 转为张量([C, H, W],[0, 1])。Normalize(mean, std):标准化(如 ImageNet 的[0.485, 0.456, 0.406],[0.229, 0.224, 0.225])。Resize(size):调整图像大小。RandomHorizontalFlip(p=0.5):随机水平翻转。RandomResizedCrop(size):随机裁剪并调整大小。ColorJitter(brightness, contrast, saturation, hue):颜色调整。- 组合变换:
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
- 函数式变换(
torchvision.transforms.functional):
提供更细粒度的控制,如手动调整亮度:
from torchvision.transforms import functional as F
image = F.adjust_brightness(image, brightness_factor=1.2)
(4) 实用工具
- 图像加载与保存:
from torchvision.io import read_image, write_image
image = read_image('image.jpg') # 加载为张量
write_image(image, 'output.jpg') # 保存张量为图像
- 可视化:
使用torchvision.utils.make_grid拼接图像网格:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
images, _ = next(iter(train_loader))
grid = make_grid(images)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.show()
3. 完整示例:图像分类
以下是一个使用 torchvision 进行图像分类的完整示例:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# 超参数
num_classes = 10
batch_size = 32
epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 数据变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
test_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
# 加载预训练模型
model = models.resnet18(weights='DEFAULT')
model.fc = nn.Linear(model.fc.in_features, num_classes) # 修改分类层
model = model.to(device)
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练
model.train()
for epoch in range(epochs):
total_loss = 0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}')
# 测试
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Test Accuracy: {100 * correct / total:.2f}%')
4. 进阶用法
- 迁移学习:
- 冻结预训练层的参数,仅训练新添加的层:
python for param in model.parameters(): param.requires_grad = False model.fc = nn.Linear(model.fc.in_features, num_classes) - 自定义数据集:
继承torch.utils.data.Dataset处理非标准数据集:
class CustomImageDataset(Dataset):
def __init__(self, img_dir, transform=None):
self.img_paths = [os.path.join(img_dir, f) for f in os.listdir(img_dir)]
self.transform = transform
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
image = Image.open(self.img_paths[idx])
label = ... # 提取标签逻辑
if self.transform:
image = self.transform(image)
return image, label
- 目标检测与分割:
使用torchvision.models.detection的模型(如 Faster R-CNN):
from torchvision.models.detection import fasterrcnn_resnet50_fpn
model = fasterrcnn_resnet50_fpn(weights='DEFAULT').to(device)
model.eval()
predictions = model([image_tensor])
- 混合精度训练:
使用torch.cuda.amp加速训练:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
outputs = model(images)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
5. 常见问题与注意事项
- 数据格式:确保图像张量为
[batch_size, channels, height, width],范围[0, 1](ToTensor后)。 - 预训练权重:使用
weights='DEFAULT'或具体版本(如ResNet18_Weights.IMAGENET1K_V1)加载最新或指定权重。 - 内存管理:大数据集使用
num_workers和pin_memory加速DataLoader:
DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)
- 归一化:预训练模型通常需要 ImageNet 的均值和标准差进行归一化。
- 版本兼容性:
torchvision随 PyTorch 更新,检查版本兼容性(如torchvision>=0.13支持新权重 API)。
6. 参考资源
- 官方文档:
torchvision:https://pytorch.org/vision/stable/index.htmldatasets:https://pytorch.org/vision/stable/datasets.htmlmodels:https://pytorch.org/vision/stable/models.htmltransforms:https://pytorch.org/vision/stable/transforms.html- 教程:https://pytorch.org/tutorials/beginner/basics/vision_tutorial.html
- GitHub 仓库:https://github.com/pytorch/vision
- 社区论坛:https://discuss.pytorch.org/
7. 进一步帮助
如果你需要特定任务的实现(如目标检测、图像分割)、优化 torchvision 使用(如数据增强策略、模型微调),或调试代码,请提供更多细节,我可以为你提供定制化的代码或建议!