PyTorch データセット作成入門:基本と実践
PyTorchには、カスタムデータセットを作成するために使用できるDatasetというクラスが提供されています。データセットを作成するには、Datasetクラスを継承し、__len__と__getitem__の2つのメソッドを実装する必要があります。
__len__メソッドは、データセットのサイズ、つまりデータサンプルの数を返します。
__getitem__メソッドは、指定されたインデックスに対応するデータサンプルを返します。このメソッドでは、データファイルを読み取り、データを前処理し、モデルに必要な入力および出力データを返すことができます。
以下は、カスタムデータセットクラスを作成する方法を示す簡単な例です。
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
# 可以对数据进行预处理
input_data = sample[:-1]
target = sample[-1]
return torch.tensor(input_data), torch.tensor(target)
上記の例では、CustomDatasetクラスはデータリストをパラメータとして受け取り、__len__と__getitem__メソッドを実装しています。__getitem__メソッドでは、データサンプルを入力データとターゲットデータに分割し、それに対応するテンソルを返します。
自作のデータセットクラスを作成したら、DataLoaderクラスを使用してデータをロードし、モデルのトレーニングを繰り返すことができます。