How to initialize model parameters in PyTorch?

In PyTorch, you can initialize model parameters by writing a function or method that applies the initializers you want to each layer. PyTorch ships a set of built-in initializers in the torch.nn.init module (functions such as xavier_uniform_, kaiming_normal_, and constant_, which modify a tensor in place). One common pattern looks like this:

import torch
import torch.nn as nn
import torch.nn.init as init

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 10)  # 100 input features -> 10 outputs

    def forward(self, x):
        return self.linear(x)

    def initialize_weights(self):
        # self.modules() yields the model itself and every nested submodule
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)  # Xavier (Glorot) uniform for weights
                if m.bias is not None:
                    init.constant_(m.bias, 0)  # biases set to 0

model = MyModel()
model.initialize_weights()

The code above defines a class called MyModel that contains a single linear layer, nn.Linear(100, 10). Its initialize_weights method iterates over every submodule returned by self.modules() and, for each nn.Linear, applies Xavier (Glorot) uniform initialization to the weight and sets the bias to 0. You can swap in other initializers from torch.nn.init as needed.
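The method above only handles nn.Linear. For models that mix layer types, a common alternative is to write a per-module function and pass it to Module.apply, which calls it on every submodule and then on the model itself. The sketch below is illustrative only: SmallNet and init_weights are hypothetical names, and it uses Kaiming (He) normal initialization instead of Xavier, a frequent choice for layers followed by ReLU. It also checks that Xavier uniform samples for nn.Linear(100, 10) stay within the bound sqrt(6 / (fan_in + fan_out)), which is about 0.234 here.

```python
import math

import torch
import torch.nn as nn
import torch.nn.init as init


class SmallNet(nn.Module):
    """Hypothetical network mixing layer types, for illustration only."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        x = x.mean(dim=(2, 3))  # global average pooling -> shape (N, 8)
        return self.fc(x)


def init_weights(m):
    """Hypothetical per-module initializer passed to Module.apply."""
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Kaiming (He) normal init, scaled for ReLU activations
        init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # Start batch norm as the identity: scale = 1, shift = 0
        init.ones_(m.weight)
        init.zeros_(m.bias)


torch.manual_seed(0)
net = SmallNet()
net.apply(init_weights)  # visits every submodule, then net itself

# Xavier uniform draws from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out));
# with the default gain of 1 and nn.Linear(100, 10), a = sqrt(6 / 110).
layer = nn.Linear(100, 10)
init.xavier_uniform_(layer.weight)
bound = math.sqrt(6.0 / (100 + 10))
print(f"bound={bound:.4f}, max |w|={layer.weight.abs().max().item():.4f}")

out = net(torch.randn(2, 3, 16, 16))
print(out.shape)  # torch.Size([2, 10])
```

Module.apply keeps the initialization logic in one place regardless of how deeply layers are nested; the isinstance checks decide which initializer each layer type receives.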
