PyTorch Dataset Loading Guide

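In PyTorch, datasets are typically loaded and processed with two classes: torch.utils.data.Dataset, which defines how to retrieve a single sample, and torch.utils.data.DataLoader, which wraps a dataset and yields batches of samples, optionally shuffled.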

First, create a custom dataset class that inherits from torch.utils.data.Dataset and implements the __len__ and __getitem__ methods. __len__ returns the number of samples, and __getitem__ loads and preprocesses the sample at a given index.

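import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        # Store the raw samples (any indexable sequence)
        self.data = data

    def __len__(self):
        # Number of samples; the default samplers use this to choose indices
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch the sample at position idx
        sample = self.data[idx]
        # Apply any preprocessing here (e.g., conversion to tensors)
        return sample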
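As a slightly fuller sketch of the same pattern, the dataset below converts its inputs to tensors once in __init__ and standardizes each feature on access in __getitem__. The class name NormalizedDataset, the toy values, and the choice of per-feature standardization are illustrative assumptions, not part of the guide above.

```python
import torch
from torch.utils.data import Dataset


class NormalizedDataset(Dataset):
    """Hypothetical map-style dataset returning (standardized features, label).

    Assumes `features` is a list of equal-length numeric lists and
    `labels` is a list of integer class ids of the same length.
    """

    def __init__(self, features, labels):
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        # Per-feature statistics, computed once over the whole dataset.
        self.mean = self.features.mean(dim=0)
        # Guard against division by zero for constant features.
        self.std = self.features.std(dim=0).clamp(min=1e-8)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Preprocessing runs per sample, based on the requested index.
        x = (self.features[idx] - self.mean) / self.std
        y = self.labels[idx]
        return x, y


ds = NormalizedDataset([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]], [0, 1, 0])
x, y = ds[1]
print(len(ds), x, y)
```

Computing the statistics once and applying them lazily keeps __getitem__ cheap, which matters because the DataLoader calls it once per sample in every epoch.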

Next, instantiate the custom dataset class and wrap it in a torch.utils.data.DataLoader, specifying the batch size and whether to shuffle the data.

data = [...]  # your dataset

dataset = CustomDataset(data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
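To see what batch_size and shuffle do in practice, the sketch below uses the built-in torch.utils.data.TensorDataset on ten made-up samples (the data is purely illustrative). With batch_size=4 the final batch holds only the 2 leftover samples unless drop_last=True is passed, and shuffling changes the order of the samples but not which samples are returned in an epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples with 3 features each, plus an integer label (illustrative data).
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# The default collate function stacks per-sample (x, y) pairs into batched tensors.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batch_sizes = [xb.shape[0] for xb, yb in loader]
print(batch_sizes)  # 10 samples split as 4 + 4 + 2

# drop_last=True discards the incomplete final batch.
loader_drop = DataLoader(dataset, batch_size=4, shuffle=True, drop_last=True)
drop_sizes = [xb.shape[0] for xb, _ in loader_drop]
print(drop_sizes)

# shuffle=True permutes the order, but each sample is still seen exactly once.
seen = torch.cat([yb for _, yb in loader]).sort().values
print(torch.equal(seen, labels))
```

Because the default collation stacks the per-sample (features, label) pairs, each batch unpacks directly into a features tensor of shape (batch, 3) and a labels tensor of shape (batch,).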

Finally, iterate over the data loader to access the data one batch at a time.

for batch in dataloader:
    # Process the batch of data here
    pass
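In training code, this loop usually runs once per epoch inside an outer loop. A minimal sketch, assuming a toy linear-regression problem; the data, model, learning rate, and epoch count are all hypothetical choices for illustration, not recommendations:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # reproducible initialization and shuffling

# Hypothetical toy regression data: y = 2x + 1 with x in [0, 1).
x = torch.rand(64, 1)
y = 2 * x + 1
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(50):  # illustrative epoch count
    total = 0.0
    for xb, yb in loader:  # each iteration yields one shuffled mini-batch
        optimizer.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch average is per sample.
        total += loss.item() * xb.shape[0]
    epoch_losses.append(total / len(loader.dataset))

print(f"first epoch: {epoch_losses[0]:.4f}, last epoch: {epoch_losses[-1]:.4f}")
```

Since shuffle=True, the DataLoader draws a new sample order at the start of each pass, so the mini-batches differ from epoch to epoch.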