nn.parameterのPyTorchでの使い方はどうですか?

PyTorchにおいて、nn.Parameterは特別なTensorであり、nn.Module内の訓練可能なパラメータの特別なタイプです。 nn.Parameterオブジェクトは、nn.Moduleのコンストラクタによって自動的に認識され、モデルの訓練可能なパラメータとして登録されます。

nn.Parameterを使用するためには、まずnn.Parameterオブジェクトを作成し、それをモデルの属性として設定する必要があります。以下は簡単な例です:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.weight = nn.Parameter(torch.rand(3, 4))  # 创建一个参数

    def forward(self, x):
        out = torch.matmul(x, self.weight)
        return out

model = MyModel()
print(model.weight)  # 打印参数

上記の例では、私たちはnn.Moduleを継承したMyModelクラスを定義しました。コンストラクタの__init__内で、形状が(3, 4)でランダムに初期化されたTensorであるnn.Parameterオブジェクトself.weightを作成しました。

forwardメソッド内では、self.weightを使用して計算することができます。モデルが作成された後は、model.weightを通じてこのパラメータにアクセスすることができます。

それに注意すべきは、nn.Parameterオブジェクトは自動的にモデルの学習可能なパラメーターとして登録され、モデルのparameters()メソッドでアクセスできることです。また、nn.Parameterオブジェクトには勾配計算の機能も自動で付与され、backward()メソッドで勾配を自動的に計算することができます。

bannerAds