PyTorchのデータセットの読み込み方法は何ですか?

PyTorchでは、torchvision.datasetsモジュールを使用して一般的なデータセットをロードすることができます。このモジュールは、以下の一般的なデータセットをサポートしています。

  1. MNISTは、手書き数字データセットです。
  2. FashionMNISTは、ファッションアイテムのデータセットです。
  3. CIFAR10/CIFAR100は、10/100のクラスを持つカラー画像データセットです。
  4. ImageNetは画像分類に使用される大規模データベースです。
  5. COCO:Object detection、image segmentation、image annotation用のデータセット。

データセットを読み込む一般的な手順は以下の通りです:

  1. モジュールをインポートしてください。
from torchvision import datasets
  1. データセットの変換の定義(オプション):
from torchvision import transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

画像をテンソルに変換し、正規化処理を行う。

  1. データセットを読み込む:
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

rootパラメーターは、データセットのダウンロードと保存先のパスを指定します。trainパラメーターはトレーニングセットを読み込むかテストセットを読み込むかを示し、transformパラメーターはデータセットへの変換を指定します。downloadパラメーターはデータセットをダウンロードするかどうかを示します(初回実行時のみダウンロードが必要)。

  1. データローダーを作成する。
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

ここでのbatch_sizeパラメータは、各バッチのサンプル数を指定し、shuffleパラメータはデータをランダムにシャッフルするかどうかを示します。

上記の手順に従うことで、PyTorchのデータセットを読み込んでトレーニングやテストに使用することができます。

bannerAds