PyTorchで訓練済みモデルをロードする方法は?
PyTorchモデルを読み込むには、torch.load()関数を使用してモデルのパラメータと状態辞書を読み込むことができます。以下は、トレーニングされたモデルを読み込んで使用する例です。
import torch
import torchvision.models as models
# 实例化模型
model = models.resnet18()
# 加载训练好的模型参数
model.load_state_dict(torch.load('path_to_saved_model.pth'))
# 设置模型为评估模式
model.eval()
# 使用模型进行推理
inputs = torch.randn(1, 3, 224, 224)
outputs = model(inputs)
# 打印预测结果
print(outputs)
上記のコードでは、まずtorchvision.modelsモジュールを使用してResNet-18モデルのインスタンスを作成しました。次に、load_state_dict()関数を使用してトレーニングされたモデルのパラメータを読み込み、モデルのパラメータが保存されているファイルパスを指定する必要があります。その後、eval()メソッドを呼び出してモデルを評価モードに設定し、モデル内のいくつかのトレーニング固有の操作、例えばDropoutなどを無効にします。最後に、入力データをモデルに渡して推論を行い、予測結果を出力します。
モデルをロードする際には、モデルの構造がトレーニング時と完全に一致していることを確認する必要があります。そうでないと、ロードされたモデルのパラメーターにエラーが発生する可能性があります。