PyTorchでデータセットをロードして処理する方法は何ですか?
PyTorchでは、通常torch.utils.data.Datasetとtorch.utils.data.DataLoaderを使用してデータセットをロードして処理します。
最初に、torch.utils.data.Datasetを継承したカスタムデータセットクラスを作成し、__len__と__getitem__メソッドを実装します。__getitem__メソッドでは、インデックスに基づいてデータを読み込みおよび前処理することができます。
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
# 进行数据预处理
return sample
その後、カスタムデータセットクラスをインスタンス化し、torch.utils.data.DataLoaderを使用してデータローダーを作成し、バッチサイズとデータのシャッフルを指定します。
data = [...] # 数据集
dataset = CustomDataset(data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
最後に、データセット内のデータにアクセスするためにデータローダーを反復処理することができます。
for batch in dataloader:
# 处理批量数据
pass