PyTorchでのモデルの保存とロードの方法は何ですか?

PyTorchで、モデルをファイルとして保存するためにtorch.save()関数を使用し、保存したモデルファイルをロードするためにtorch.load()関数を使用することができます。以下はモデル保存とロードの例です:

import torch
import torch.nn as nn

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        x = self.fc(x)
        return x

model = Net()

# 保存模型
torch.save(model.state_dict(), 'model.pth')

# 加载模型
model.load_state_dict(torch.load('model.pth'))

上記のコードでは、model.state_dict()関数はモデルのパラメータ状態の辞書を取得し、それをtorch.save()関数を使用してファイルに保存します。モデルをロードする際には、保存したモデルファイルをtorch.load()関数でロードし、model.load_state_dict()関数を使用してモデルのパラメータをモデルにロードします。

モデルを保存する際は、モデルのパラメーターのみが保存され、モデルの構造は保存されません。モデルをロードする際には、まず同じモデル構造を作成してからパラメーターをロードする必要があります。

bannerAds