PyTorch DataLoader: Data Loading Guide

Loading data with DataLoader in PyTorch involves four main steps.

  1. Define a dataset class: Write a class that inherits from torch.utils.data.Dataset and implements the __len__ and __getitem__ methods. The __len__ method should return the number of samples in the dataset, while the __getitem__ method should return the sample at the given index.
  2. Create a dataset instance: Instantiate the dataset class defined in step 1.
  3. Create a data loader: Use the torch.utils.data.DataLoader class to create a data loader, passing the dataset instance as the first argument. Parameters such as batch_size (the number of samples per batch) and shuffle (whether the samples are reshuffled at every epoch) control how the data is loaded.
  4. Iterate through the data loader: Use a for loop to iterate through the data loader, where each iteration will return a batch of data. You can then pass this data into the model for training.
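
Step 4 mentions passing each batch into a model. As a rough sketch of what that can look like, the snippet below runs one pass over a loader whose dataset returns (feature, label) pairs. The PairDataset class, the random data, the single-layer model, and the learning rate are all made up for illustration and are not part of the original example.

```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader


class PairDataset(Dataset):
    """Hypothetical dataset that returns (feature, label) pairs."""

    def __init__(self, n_samples=8, n_features=3):
        # Random placeholder data; a real dataset would read from files or arrays.
        self.features = torch.randn(n_samples, n_features)
        self.labels = torch.randn(n_samples, 1)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


loader = DataLoader(PairDataset(), batch_size=4, shuffle=True)
model = nn.Linear(3, 1)  # toy model, chosen only to keep the sketch short
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

batch_shapes = []
for features, labels in loader:
    # The default collate function stacks the samples along a new first dimension.
    batch_shapes.append((tuple(features.shape), tuple(labels.shape)))
    optimizer.zero_grad()
    loss = loss_fn(model(features), labels)
    loss.backward()
    optimizer.step()

print(batch_shapes)
```

When __getitem__ returns a tuple, each batch is a tuple as well, so it can be unpacked directly in the for statement.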

The following sample code walks through all four steps:

import torch
from torch.utils.data import Dataset, DataLoader

# Define the dataset class
class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# Create a dataset instance
dataset = MyDataset()

# Create the data loader
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Iterate over the data loader
for batch_data in dataloader:
    print(batch_data)

In the example above, a simple dataset class called MyDataset is defined first and then instantiated. A data loader named dataloader is then created with the DataLoader class, using a batch size of 2 and shuffle set to True. Finally, a for loop iterates over the data loader. Each iteration yields one batch as a tensor of up to 2 elements; because the dataset holds 5 samples, the final batch contains only 1 element.
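
How the batches come out when the dataset size is not a multiple of batch_size can be checked directly. The sketch below is only an illustration: it uses a plain Python list in place of MyDataset (an assumption that works because a list supports len() and indexing), leaves shuffle off so the output is deterministic, and introduces drop_last, a DataLoader parameter that the original example does not use.

```python
import torch
from torch.utils.data import DataLoader

# A plain list stands in for MyDataset here; it supports len() and indexing.
data = [1, 2, 3, 4, 5]

# shuffle is left at its default (False) so the batch contents are deterministic.
keep_last = [batch.tolist() for batch in DataLoader(data, batch_size=2)]
drop_last = [batch.tolist() for batch in DataLoader(data, batch_size=2, drop_last=True)]

print(keep_last)  # [[1, 2], [3, 4], [5]]
print(drop_last)  # [[1, 2], [3, 4]]
```

Setting drop_last=True discards the incomplete final batch, which can be useful when the downstream code expects every batch to have the same size.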
