PyTorchでデータセットをカスタマイズする方法は何ですか?

PyTorchでは、torch.utils.data.Datasetクラスを継承してデータセットをカスタマイズすることができます。カスタムデータセットは、__len__と__getitem__の2つのメソッドを実装する必要があります。

__len__メソッドは、データセットのサイズ、つまりサンプル数を返します。__getitem__メソッドは、指定されたインデックスに対応するサンプルを返します。

以下には、簡単なデータセットをカスタマイズする方法の例が示されています。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        sample = self.data[index]
        # 在这里对样本进行处理,例如进行预处理或转换
        return sample

上記の例では、CustomDatasetクラスはデータパラメーターを受け取ります。このパラメーターは、すべてのサンプルを含むリストまたは配列です。 __len__メソッドはデータセットのサイズを返し、__getitem__メソッドは指定されたインデックスに基づいて対応するサンプルを返します。

自作のデータセットを使用する場合、torch.utils.data.DataLoaderを使ってモデルと一緒に批量処理や繰り返し学習を行うことができます。

# 创建自定义数据集
data = [...]
dataset = CustomDataset(data)

# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# 迭代数据加载器
for batch in dataloader:
    # 在这里进行模型训练或推断

最初に、カスタムデータセットdatasetを作成し、その後、torch.utils.data.DataLoaderを使用してデータローダーdataloaderを作成しました。batch_sizeパラメーターは各バッチのサンプル数を指定し、shuffle=Trueパラメーターはデータをランダムにシャッフルすることを示しています。

最後に、各バッチのサンプルを取得してモデルのトレーニングや推論に使用するために、dataloaderを反復処理できます。

bannerAds