PyTorch 数据集

在 PyTorch 中,Dataset 是一个用于处理数据的核心抽象类,位于 torch.utils.data 模块中。它为数据加载和预处理提供了灵活的接口,通常与 DataLoader 配合使用,以便高效地加载和批处理数据。以下是对 PyTorch 数据集的详细说明:

1. PyTorch Dataset 类的核心

PyTorch 提供了 torch.utils.data.Dataset 基类,你可以通过继承它来创建自定义数据集。Dataset 类需要实现两个主要方法:

  • __len__(self):返回数据集的总样本数。
  • __getitem__(self, index):根据给定的索引返回一个样本(通常包括数据和标签)。

示例:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data  # 假设 data 是输入数据
        self.labels = labels  # 假设 labels 是对应的标签

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

2. 内置数据集

PyTorch 的 torchvision.datasets 模块提供了许多常用的预定义数据集,例如:

  • MNIST:手写数字数据集。
  • CIFAR-10/CIFAR-100:小型图像分类数据集。
  • ImageNet:需要自定义加载,但 torchvision 提供支持。
  • FashionMNIST:类似于 MNIST,但包含时尚物品图像。

使用示例:

from torchvision import datasets, transforms

# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

3. DataLoader 配合使用

Dataset 通常与 DataLoader 一起使用,DataLoader 提供批量加载、数据打乱(shuffle)、多线程加载等功能。

示例:

from torch.utils.data import DataLoader

# 创建 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True, num_workers=2)

# 迭代数据
for images, labels in train_loader:
    # 训练代码
    print(images.shape, labels.shape)  # 示例输出: torch.Size([64, 1, 28, 28]) torch.Size([64])

4. 自定义数据集

对于自定义数据(例如本地图像、CSV 文件等),你需要继承 Dataset 类并实现数据加载逻辑。以下是一个处理图像文件夹的示例:

import os
from PIL import Image
from torch.utils.data import Dataset

class CustomImageDataset(Dataset):
    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_names = os.listdir(img_dir)
        self.labels = [int(name.split('_')[0]) for name in self.img_names]  # 假设文件名包含标签

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.img_names[index])
        image = Image.open(img_path)
        label = self.labels[index]

        if self.transform:
            image = self.transform(image)

        return image, label

5. 数据增强与预处理

torchvision.transforms 提供了丰富的图像预处理和数据增强方法,例如:

  • ToTensor():将 PIL 图像或 NumPy 数组转换为 PyTorch 张量。
  • Normalize(mean, std):标准化图像。
  • RandomCrop, RandomHorizontalFlip:数据增强操作。

示例:

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

6. 常见问题与注意事项

  • 内存管理:对于大数据集,避免一次性加载所有数据到内存,可在 __getitem__ 中动态加载。
  • 多线程加载:设置 DataLoadernum_workers 参数以加速数据加载,但 Windows 系统可能需要设置 num_workers=0 或使用 if __name__ == '__main__':
  • 数据格式:确保 __getitem__ 返回的数据格式与模型输入一致(例如,图像张量的形状为 [C, H, W])。
  • 数据集划分:可以使用 torch.utils.data.Subsetrandom_split 进行训练/验证集划分。

示例(数据集划分):

from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

7. 进阶用法

  • 分布式训练:结合 torch.utils.data.distributed.DistributedSampler 支持多 GPU 训练。
  • 自定义采样:通过 torch.utils.data.WeightedRandomSampler 实现加权采样,处理类别不平衡。
  • 在线数据增强:在 __getitem__ 中动态应用增强操作。

8. 参考资源

  • PyTorch 官方文档:https://pytorch.org/docs/stable/data.html
  • torchvision.datasets:https://pytorch.org/vision/stable/datasets.html
  • 社区教程:如 PyTorch 官方教程或 GitHub 上的开源项目。

如果你有具体的数据集需求(例如处理特定格式的数据、优化加载速度等),可以进一步提供细节,我会为你量身定制解决方案!

类似文章

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注