PyTorchでのモデルの保存とロードの方法は何ですか?
PyTorchで、モデルをファイルとして保存するためにtorch.save()関数を使用し、保存したモデルファイルをロードするためにtorch.load()関数を使用することができます。以下はモデル保存とロードの例です:
import torch
import torch.nn as nn
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
x = self.fc(x)
return x
model = Net()
# 保存模型
torch.save(model.state_dict(), 'model.pth')
# 加载模型
model.load_state_dict(torch.load('model.pth'))
上記のコードでは、model.state_dict()関数はモデルのパラメータ状態の辞書を取得し、それをtorch.save()関数を使用してファイルに保存します。モデルをロードする際には、保存したモデルファイルをtorch.load()関数でロードし、model.load_state_dict()関数を使用してモデルのパラメータをモデルにロードします。
モデルを保存する際は、モデルのパラメーターのみが保存され、モデルの構造は保存されません。モデルをロードする際には、まず同じモデル構造を作成してからパラメーターをロードする必要があります。