PyTorchモデルの保存と読み込み方法は?

PyTorchモデルは、保存と読み込みを次の方法で行うことができます。

モデルを保存する:

# 保存整个模型
torch.save(model, 'model.pth')

# 保存模型的state_dict
torch.save(model.state_dict(), 'model_state_dict.pth')

モデルの読み込み:

# 加载整个模型
model = torch.load('model.pth')

# 创建模型实例并加载state_dict
model = Model()
model.load_state_dict(torch.load('model_state_dict.pth'))

モデルをロードするときには、モデルの構造が保存されたものと同じであることを確認してください。state_dictのみ保存されている場合は、まずモデルのインスタンスを作成してからstate_dictをロードする必要があります。

bannerAds