nn.parameterのPyTorchでの使い方はどうですか?
PyTorchにおいて、nn.Parameterは特別なTensorであり、nn.Module内の訓練可能なパラメータの特別なタイプです。 nn.Parameterオブジェクトは、nn.Moduleのコンストラクタによって自動的に認識され、モデルの訓練可能なパラメータとして登録されます。
nn.Parameterを使用するためには、まずnn.Parameterオブジェクトを作成し、それをモデルの属性として設定する必要があります。以下は簡単な例です:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.weight = nn.Parameter(torch.rand(3, 4)) # 创建一个参数
def forward(self, x):
out = torch.matmul(x, self.weight)
return out
model = MyModel()
print(model.weight) # 打印参数
上記の例では、私たちはnn.Moduleを継承したMyModelクラスを定義しました。コンストラクタの__init__内で、形状が(3, 4)でランダムに初期化されたTensorであるnn.Parameterオブジェクトself.weightを作成しました。
forwardメソッド内では、self.weightを使用して計算することができます。モデルが作成された後は、model.weightを通じてこのパラメータにアクセスすることができます。
それに注意すべきは、nn.Parameterオブジェクトは自動的にモデルの学習可能なパラメーターとして登録され、モデルのparameters()メソッドでアクセスできることです。また、nn.Parameterオブジェクトには勾配計算の機能も自動で付与され、backward()メソッドで勾配を自動的に計算することができます。