torch.loadをPyTorchで使用する方法は何ですか?
PyTorchでは、torch.load()関数は保存されたモデルやテンソルをロードするために使用されます。基本的な構文は次のようになります:
torch.load(filepath, map_location=None, pickle_module=<module 'pickle' from '...'>)
- filepathはモデルやテンソルを保存するためのファイルパスです。
- map_locationは、モデル/テンソルをどの場所にロードするかを指定するオプションパラメータです。デバイス名(’cpu’、’cuda:0’など)を表す文字列であるか、torch.deviceオブジェクトであるかもしれません。デフォルト値はNoneで、保存時と同じデバイスにロードされます。
- pickle_moduleは、デフォルトのpickleモジュールを上書きするためのオプションパラメータです。デフォルト値はPythonの組み込みのpickleモジュールです。
torch.load()関数の使用例は以下の通りです:
import torch
# 加载保存的模型
model = torch.load('model.pth')
# 加载保存的张量
tensor = torch.load('tensor.pt')
# 加载保存的模型,并将其加载到指定设备上
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = torch.load('model.pth', map_location=device)
# 加载保存的模型,使用自定义的pickle模块
import pickle5 as pickle
model = torch.load('model.pth', pickle_module=pickle)
torch.load()関数は、同じバージョンのPyTorchで保存されたモデルやテンソルのみを読み込むことができます。異なるバージョンのPyTorchで保存されたモデルやテンソルを読み込む場合は、他の方法を使用して変換や読み込みを行う必要があります。