LSTM and GRU in PyTorch: Complete Guide
LSTM (Long Short-Term Memory) and GRU (Gated Recurrent Unit) are implemented in PyTorch's torch.nn module: you create them with the torch.nn.LSTM and torch.nn.GRU classes.
Here is a simple example demonstrating how to use LSTM and GRU in PyTorch.
import torch
import torch.nn as nn
# Define the input dimensions
input_size = 10   # features per time step
hidden_size = 20  # hidden units
seq_len = 5       # time steps per sequence
batch_size = 3    # sequences per batch
# Shape (seq_len, batch_size, input_size): the default layout, since batch_first=False
input_data = torch.randn(seq_len, batch_size, input_size)
# LSTM: returns per-time-step outputs plus the final hidden state h_n and cell state c_n
lstm = nn.LSTM(input_size, hidden_size)
output, (h_n, c_n) = lstm(input_data)
print("LSTM output shape:", output.shape)     # torch.Size([5, 3, 20])
print("LSTM hidden state shape:", h_n.shape)  # torch.Size([1, 3, 20])
print("LSTM cell state shape:", c_n.shape)    # torch.Size([1, 3, 20])
# GRU: same interface, but there is no cell state, so only h_n is returned
gru = nn.GRU(input_size, hidden_size)
output, h_n = gru(input_data)
print("GRU output shape:", output.shape)      # torch.Size([5, 3, 20])
print("GRU hidden state shape:", h_n.shape)   # torch.Size([1, 3, 20])
In the example above, we first define the dimensions of the input data and create an LSTM and a GRU model using the torch.nn.LSTM and torch.nn.GRU classes, respectively. We then pass the input data through these models and examine the shapes of their outputs and hidden states. Because we did not supply initial states, both models start from zero states; you can also pass them explicitly, as sketched below.
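Here is a minimal, self-contained sketch of passing explicit initial states; the zero tensors here simply stand in for whatever initialization you need (the LSTM takes a (h_0, c_0) tuple, while the GRU takes h_0 alone):
import torch
import torch.nn as nn
input_size, hidden_size, seq_len, batch_size = 10, 20, 5, 3
lstm = nn.LSTM(input_size, hidden_size)
gru = nn.GRU(input_size, hidden_size)
input_data = torch.randn(seq_len, batch_size, input_size)
# Initial states have shape (num_layers * num_directions, batch_size, hidden_size)
h_0 = torch.zeros(1, batch_size, hidden_size)
c_0 = torch.zeros(1, batch_size, hidden_size)
output, (h_n, c_n) = lstm(input_data, (h_0, c_0))  # LSTM: (h_0, c_0) tuple
output, h_n = gru(input_data, h_0)                 # GRU: h_0 only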
It is worth noting that the output shapes depend on the input layout and the model's parameter settings. With the defaults used here (batch_first=False, num_layers=1, unidirectional), output has shape (seq_len, batch_size, hidden_size), i.e. torch.Size([5, 3, 20]), while h_n and c_n have shape (num_layers * num_directions, batch_size, hidden_size), i.e. torch.Size([1, 3, 20]).
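To make this concrete, the following sketch uses illustrative settings (num_layers=2, bidirectional=True, batch_first=True; none of these appear in the example above) and notes the shapes each produces:
import torch
import torch.nn as nn
# A 2-layer bidirectional LSTM taking batch-first input
lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2,
               bidirectional=True, batch_first=True)
x = torch.randn(3, 5, 10)  # (batch_size, seq_len, input_size)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([3, 5, 40]): hidden_size * 2 directions
print(h_n.shape)     # torch.Size([4, 3, 20]): num_layers * 2 directions
print(c_n.shape)     # torch.Size([4, 3, 20]): same layout as h_n
Note that h_n and c_n keep the (num_layers * num_directions, batch_size, hidden_size) layout even with batch_first=True; only the input and output tensors become batch-first.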