PyTorchでデータローダーを使用してデータを読み込む方法は何ですか?
PyTorchでは、データを読み込むためにtorch.utils.data.DataLoaderクラスを使用することができます。DataLoaderは、データセットを小さなバッチに分割して読み込む 可変長のデータローダーを提供し、トレーニングを容易にします。
以下は、DataLoaderを使用してデータを読み込む例です。
- 必要なライブラリをインポートする。
import torch
from torch.utils.data import DataLoader
- データセット
- torch.utils.data.Datasetを日本語で言い換えると、「torch.utils.data.Dataset」です。
- __len__ メソッド
- __getitem__を使う
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
- データセット
dataset = CustomDataset(data)
- データローダー
- データセット
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
- データローダー
- 以下を日本語で自然に言い換えると、選択肢は1つだけです:「列挙する」
→「挙げる」
for i, batch in enumerate(dataloader):
inputs = batch
# 在这里执行模型的前向传播和训练操作
DataLoaderはデータのバッチを返すことに注意が必要です。各サンプルのインデックスを取得したい場合は、enumerate関数を使用して取得できます。上記の例では、batchはサイズ32のバッチであり、inputsはそのバッチのデータになります。
願わくば、お役に立てますように!