PyTorchでデータセットをロードして処理する方法は何ですか?

PyTorchでは、通常torch.utils.data.Datasetとtorch.utils.data.DataLoaderを使用してデータセットをロードして処理します。

最初に、torch.utils.data.Datasetを継承したカスタムデータセットクラスを作成し、__len__と__getitem__メソッドを実装します。__getitem__メソッドでは、インデックスに基づいてデータを読み込みおよび前処理することができます。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]
        # 进行数据预处理
        return sample

その後、カスタムデータセットクラスをインスタンス化し、torch.utils.data.DataLoaderを使用してデータローダーを作成し、バッチサイズとデータのシャッフルを指定します。

data = [...]  # 数据集

dataset = CustomDataset(data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

最後に、データセット内のデータにアクセスするためにデータローダーを反復処理することができます。

for batch in dataloader:
    # 处理批量数据
    pass
bannerAds