PyTorchにおけるモデルの保存とロードの方法は何ですか?
PyTorchには、モデルを保存して読み込むための関数であるtorch.save()とtorch.load()が提供されています。
- モデルの保存:torch.save(model.state_dict(), PATH)関数を使用すると、モデルのパラメータを指定したPATHに保存することができます。
- モデルのロード:
最初に、元のモデル構造と同じ空のモデルを作成する必要があります。 - model = ModelClass(*args, **kwargs) # 空のモデルインスタンスを作成
- その後、torch.load()関数を使用して保存されたモデルパラメーターをロードし、空のモデルに割り当てます。
- model.load_state_dict(torch.load(PATH)) –> model.load_state_dict(torch.load(PATH)を読み込む)
- 最後に、ロードされたモデルを使用して予測またはトレーニングを行うことができます。
モデルを保存および読み込む際には、モデルの構造とパラメータの形状が一致していることを確認する必要があります。違うとエラーが発生する恐れがあります。