torch.load関数はPyTorchの中で何をするのですか?
torch.load関数は、保存されたPyTorchモデルやテンソルをハードディスクから読み込むために使用されます。.pth、.pt、.pklなどのファイルを読み込み、モデルの重み、ネットワーク構造、トレーニング状態などの情報を含むPython辞書を返します。
torch.load関数を使用すると、事前に学習済みモデルを簡単に読み込んで、新しいタスクで微調整や推論を行うことができます。読み込んだモデルは評価や予測の生成、またはさらなるトレーニングに活用できます。
使用例:
model = torch.load('model.pth')
さらに、torch.load関数は、指定されたデバイスにモデルをロードするためにmap_locationパラメータを指定することもできます。例えば、モデルをGPUにロードすることができます。
model = torch.load('model.pth', map_location='cuda:0')