PyTorchでは、モデルのパラメータを初期化する方法は何ですか?

PyTorchでは、モデルのパラメータを初期化するために関数を定義することができます。通常、PyTorchはtorch.nn.initモジュール内にいくつかの組み込みの初期化メソッドを提供しています。以下はよく使われる初期化方法の1つです:

import torch
import torch.nn as nn
import torch.nn.init as init

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(100, 10)

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

model = MyModel()
model.initialize_weights()

上記のコードでは、MyModelクラスを定義し、中にはnn.Linear(100, 10)という線形層が含まれています。モデルのパラメータを初期化するためにinitialize_weights関数を使用し、重みにはXavierの初期化方法を適用し、バイアスは0で初期化します。必要に応じて他の初期化方法を選択することもできます。

bannerAds