PyTorchのバッチ予測方法は何ですか?
PyTorchを使用したバッチ予測の方法では、通常、DataLoaderを使用してデータバッチをロードし、それらをモデルに渡して推論を行う手順が含まれます。具体的な手順は以下の通りです:
- データセットを作成するには、まずtorch.utils.data.Datasetを継承し、__len__と__getitem__メソッドを実装するカスタムデータセットクラスを構築する必要があります。これによりデータセットの長さとデータサンプルを返すことができます。
- torch.utils.data.DataLoaderクラスを使用してデータローダーを作成します。これにより、データをバッチサイズに分割して処理することができます。データローダーを作成する際には、使用するデータセット、バッチサイズ、データのシャッフルの有無などのパラメータを指定する必要があります。
- PyTorchモデルをロードして、トレーニング済みのモデルを読み込むことができます。torch.loadを使って重みだけでなく、モデル全体をロードすることができます。
- バッチ予測:ロードされたモデルを使用してデータを一括で予測します。各データバッチに対して、予測結果を取得するためには、モデル.forward()メソッドを使用する必要があります。
以下是一个简单示例代码:
import torch
from torch.utils.data import DataLoader
# 1. 构建数据集类
class MyDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 2. 创建数据加载器
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
dataset = MyDataset(data)
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)
# 3. 加载模型
model = torch.load('model.pth')
# 4. 批量预测
predictions = []
for batch in dataloader:
inputs = batch # 根据自定义的数据集类,每个batch都是一个样本
outputs = model(inputs)
predictions.extend(outputs.tolist())
上記の例では、数字1から10までを含むデータセットクラスMyDatasetを構築しました。その後、データセットをバッチごとに分割するdataloaderを作成し、各バッチには3つのサンプルが含まれます。次に、トレーニング済みのモデルmodelをロードし、データローダーを使用してデータを一括予測しました。最後に、予測結果はpredictionsリストに保存されました。