PyTorch Dataset Loading Guide
In PyTorch, datasets are typically loaded and processed using torch.utils.data.Dataset, which defines how individual samples are accessed, and torch.utils.data.DataLoader, which groups those samples into batches.
First, create a custom dataset class that inherits from torch.utils.data.Dataset and implements the __len__ and __getitem__ methods. __len__ returns the number of samples, and __getitem__ loads and preprocesses the sample at the given index.
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        # Preprocess the sample here
        return sample
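To make the preprocessing step more concrete, here is a minimal sketch of a dataset that returns (features, label) pairs. The class name PairDataset, the toy values, and the choice of scaling by 255 are all assumptions made for illustration; they are not part of the guide above.

```python
import torch
from torch.utils.data import Dataset

class PairDataset(Dataset):
    """Hypothetical in-memory dataset of (features, label) pairs."""

    def __init__(self, features, labels):
        if len(features) != len(labels):
            raise ValueError("features and labels must have the same length")
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # Assumed preprocessing: convert to float32 and scale values into [0, 1].
        x = torch.as_tensor(self.features[idx], dtype=torch.float32) / 255.0
        y = torch.as_tensor(self.labels[idx], dtype=torch.long)
        return x, y

# Toy data standing in for a real dataset.
features = [[0, 128, 255], [255, 255, 0], [64, 64, 64]]
labels = [0, 1, 0]

pair_dataset = PairDataset(features, labels)
x0, y0 = pair_dataset[0]
```

Doing the conversion inside __getitem__ means each sample is processed only when it is requested, which is usually preferable when the full dataset does not fit comfortably in memory.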
Next, instantiate the custom dataset class and create a data loader using torch.utils.data.DataLoader, specifying the batch size and whether the data should be shuffled.
data = [...]  # the dataset
dataset = CustomDataset(data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
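The effect of batch_size is easiest to see on a small example. The sketch below uses TensorDataset with ten toy samples (an assumption, chosen only to keep the example short) and shows that the final batch is smaller when the dataset size is not a multiple of the batch size, unless drop_last=True is passed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples with four features each, standing in for real data.
toy_dataset = TensorDataset(torch.arange(40, dtype=torch.float32).reshape(10, 4))

# With batch_size=4, ten samples are split into batches of 4, 4, and 2.
loader = DataLoader(toy_dataset, batch_size=4, shuffle=True)
batch_sizes = [batch[0].shape[0] for batch in loader]

# drop_last=True discards the final, incomplete batch.
loader_drop = DataLoader(toy_dataset, batch_size=4, shuffle=True, drop_last=True)
batch_sizes_drop = [batch[0].shape[0] for batch in loader_drop]
```

With shuffle=True the sample order is redrawn each time the loader is iterated, so the contents of each batch vary between epochs while the batch sizes stay the same.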
Finally, iterate over the data loader to access the data in batches.
for batch in dataloader:
    # Process the batch here
    pass
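When the dataset returns (input, target) pairs, each batch can be unpacked directly in the loop. The following is a rough sketch under that assumption; the random tensors, the two-epoch loop, and the seen counter are illustrative placeholders for a real training or evaluation step.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random toy data: ten samples with three features and a binary target.
inputs = torch.randn(10, 3)
targets = torch.randint(0, 2, (10,))
loader = DataLoader(TensorDataset(inputs, targets), batch_size=4, shuffle=True)

seen = 0
for epoch in range(2):
    for x, y in loader:
        # x has shape (batch, 3) and y has shape (batch,).
        # A real loop would run the model and compute a loss here.
        seen += x.shape[0]
```

Each pass over the loader visits every sample once, so after two epochs the counter equals twice the dataset size.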