How to initialize model parameters in PyTorch?
In PyTorch, you can initialize model parameters by writing a function or method that visits the model's layers and applies an initializer to each one. PyTorch provides built-in initializers in the torch.nn.init module. One common approach is as follows:
```python
import torch
import torch.nn as nn
import torch.nn.init as init


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(100, 10)

    def forward(self, x):
        return self.linear(x)

    def initialize_weights(self):
        # Visit every submodule and re-initialize each linear layer
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.xavier_uniform_(m.weight)  # Xavier (Glorot) uniform for weights
                if m.bias is not None:
                    init.constant_(m.bias, 0)  # biases set to zero


model = MyModel()
model.initialize_weights()
```
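In the code above, MyModel contains a single linear layer, nn.Linear(100, 10). The initialize_weights method iterates over all submodules with self.modules() and, for each nn.Linear, applies Xavier (Glorot) uniform initialization to the weights and sets the biases to 0. Because nn.Linear already initializes its parameters when it is created, initialize_weights overwrites those defaults, so call it after constructing the model. torch.nn.init also offers other initializers, such as kaiming_uniform_, kaiming_normal_, xavier_normal_, normal_, and zeros_, which you can swap in as needed.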
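For reference, init.xavier_uniform_ samples weights from a uniform distribution U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)), where gain defaults to 1. For nn.Linear(100, 10), fan_in = 100 and fan_out = 10, so a = sqrt(6 / 110) ≈ 0.2335. The sketch below checks that bound on a freshly initialized layer; the helper name xavier_bound is hypothetical and only mirrors the formula.

```python
import math

import torch.nn as nn
import torch.nn.init as init


def xavier_bound(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    # Hypothetical helper: bound a of U(-a, a) used by xavier_uniform_
    return gain * math.sqrt(6.0 / (fan_in + fan_out))


layer = nn.Linear(100, 10)  # weight shape is (out_features, in_features) = (10, 100)
init.xavier_uniform_(layer.weight)
init.constant_(layer.bias, 0)

a = xavier_bound(fan_in=100, fan_out=10)
print(f"bound = {a:.4f}")
# Every sampled weight should lie inside [-a, a] (small tolerance for float32 rounding)
print(layer.weight.abs().max().item() <= a + 1e-6)
print(bool((layer.bias == 0).all()))
```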
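Another common pattern is to pass an initialization function to nn.Module.apply, which calls it recursively on every submodule as well as the module itself. The sketch below swaps in Kaiming (He) normal initialization, which is often paired with ReLU activations; the two-layer network and the init_weights name are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.init as init


def init_weights(m: nn.Module) -> None:
    # Called once for every submodule by model.apply()
    if isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, nonlinearity="relu")
        if m.bias is not None:
            init.zeros_(m.bias)


# Hypothetical two-layer network used only for illustration
model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 10),
)
model.apply(init_weights)  # modifies the parameters in place and returns the model

out = model(torch.randn(4, 100))
print(out.shape)  # torch.Size([4, 10])
```

Inside a custom nn.Module, calling self.apply(init_weights) at the end of __init__ applies the same scheme automatically every time the model is constructed.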