torch.loadをPyTorchで使用する方法は何ですか?

PyTorchでは、torch.load()関数は保存されたモデルやテンソルをロードするために使用されます。基本的な構文は次のようになります:

torch.load(filepath, map_location=None, pickle_module=<module 'pickle' from '...'>)
  1. filepathはモデルやテンソルを保存するためのファイルパスです。
  2. map_locationは、モデル/テンソルをどの場所にロードするかを指定するオプションパラメータです。デバイス名(’cpu’、’cuda:0’など)を表す文字列であるか、torch.deviceオブジェクトであるかもしれません。デフォルト値はNoneで、保存時と同じデバイスにロードされます。
  3. pickle_moduleは、デフォルトのpickleモジュールを上書きするためのオプションパラメータです。デフォルト値はPythonの組み込みのpickleモジュールです。

torch.load()関数の使用例は以下の通りです:

import torch

# 加载保存的模型
model = torch.load('model.pth')

# 加载保存的张量
tensor = torch.load('tensor.pt')

# 加载保存的模型,并将其加载到指定设备上
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = torch.load('model.pth', map_location=device)

# 加载保存的模型,使用自定义的pickle模块
import pickle5 as pickle
model = torch.load('model.pth', pickle_module=pickle)

torch.load()関数は、同じバージョンのPyTorchで保存されたモデルやテンソルのみを読み込むことができます。異なるバージョンのPyTorchで保存されたモデルやテンソルを読み込む場合は、他の方法を使用して変換や読み込みを行う必要があります。

bannerAds