PyTorchのバッチ予測方法は何ですか?

PyTorchを使用したバッチ予測の方法では、通常、DataLoaderを使用してデータバッチをロードし、それらをモデルに渡して推論を行う手順が含まれます。具体的な手順は以下の通りです:

  1. データセットを作成するには、まずtorch.utils.data.Datasetを継承し、__len__と__getitem__メソッドを実装するカスタムデータセットクラスを構築する必要があります。これによりデータセットの長さとデータサンプルを返すことができます。
  2. torch.utils.data.DataLoaderクラスを使用してデータローダーを作成します。これにより、データをバッチサイズに分割して処理することができます。データローダーを作成する際には、使用するデータセット、バッチサイズ、データのシャッフルの有無などのパラメータを指定する必要があります。
  3. PyTorchモデルをロードして、トレーニング済みのモデルを読み込むことができます。torch.loadを使って重みだけでなく、モデル全体をロードすることができます。
  4. バッチ予測:ロードされたモデルを使用してデータを一括で予測します。各データバッチに対して、予測結果を取得するためには、モデル.forward()メソッドを使用する必要があります。

以下是一个简单示例代码:

import torch
from torch.utils.data import DataLoader

# 1. 构建数据集类
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 2. 创建数据加载器
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
dataset = MyDataset(data)
dataloader = DataLoader(dataset, batch_size=3, shuffle=False)

# 3. 加载模型
model = torch.load('model.pth')

# 4. 批量预测
predictions = []
for batch in dataloader:
    inputs = batch  # 根据自定义的数据集类,每个batch都是一个样本
    outputs = model(inputs)
    predictions.extend(outputs.tolist())

上記の例では、数字1から10までを含むデータセットクラスMyDatasetを構築しました。その後、データセットをバッチごとに分割するdataloaderを作成し、各バッチには3つのサンプルが含まれます。次に、トレーニング済みのモデルmodelをロードし、データローダーを使用してデータを一括予測しました。最後に、予測結果はpredictionsリストに保存されました。

bannerAds