PyTorchのデータセットの読み込み方法は何ですか?
PyTorchでは、torchvision.datasetsモジュールを使用して一般的なデータセットをロードすることができます。このモジュールは、以下の一般的なデータセットをサポートしています。
- MNISTは、手書き数字データセットです。
- FashionMNISTは、ファッションアイテムのデータセットです。
- CIFAR10/CIFAR100は、10/100のクラスを持つカラー画像データセットです。
- ImageNetは画像分類に使用される大規模データベースです。
- COCO:Object detection、image segmentation、image annotation用のデータセット。
データセットを読み込む一般的な手順は以下の通りです:
- モジュールをインポートしてください。
from torchvision import datasets
- データセットの変換の定義(オプション):
from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
画像をテンソルに変換し、正規化処理を行う。
- データセットを読み込む:
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
rootパラメーターは、データセットのダウンロードと保存先のパスを指定します。trainパラメーターはトレーニングセットを読み込むかテストセットを読み込むかを示し、transformパラメーターはデータセットへの変換を指定します。downloadパラメーターはデータセットをダウンロードするかどうかを示します(初回実行時のみダウンロードが必要)。
- データローダーを作成する。
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
ここでのbatch_sizeパラメータは、各バッチのサンプル数を指定し、shuffleパラメータはデータをランダムにシャッフルするかどうかを示します。
上記の手順に従うことで、PyTorchのデータセットを読み込んでトレーニングやテストに使用することができます。