PyTorch データセット作成入門:基本と実践

PyTorchには、カスタムデータセットを作成するために使用できるDatasetというクラスが提供されています。データセットを作成するには、Datasetクラスを継承し、__len__と__getitem__の2つのメソッドを実装する必要があります。

__len__メソッドは、データセットのサイズ、つまりデータサンプルの数を返します。

__getitem__メソッドは、指定されたインデックスに対応するデータサンプルを返します。このメソッドでは、データファイルを読み取り、データを前処理し、モデルに必要な入力および出力データを返すことができます。

以下は、カスタムデータセットクラスを作成する方法を示す簡単な例です。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        # 可以对数据进行预处理
        input_data = sample[:-1]
        target = sample[-1]
        return torch.tensor(input_data), torch.tensor(target)

上記の例では、CustomDatasetクラスはデータリストをパラメータとして受け取り、__len__と__getitem__メソッドを実装しています。__getitem__メソッドでは、データサンプルを入力データとターゲットデータに分割し、それに対応するテンソルを返します。

自作のデータセットクラスを作成したら、DataLoaderクラスを使用してデータをロードし、モデルのトレーニングを繰り返すことができます。

bannerAds