PyTorchにおけるリカレントニューラルネットワークの実装方法は何ですか?

PyTorchでは、リカレントニューラルネットワーク(RNN)はtorch.nn.RNNやtorch.nn.LSTMなどのモジュールを使用して実装することができます。これらのモジュールはすべてtorch.nn.Moduleクラスを継承しており、内部でRNNの計算プロセスをカプセル化しています。

PyTorchで基本的なリカレントニューラルネットワークモデルを作成する方法を示す簡単な例を以下に示します。

import torch
import torch.nn as nn

# 定义RNN模型
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        
    def forward(self, x, h0):
        out, hn = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])  # 只取最后一个时间步的输出作为预测结果
        return out

# 定义输入参数
input_size = 28
hidden_size = 128
num_layers = 1
output_size = 10

# 创建模型实例
rnn = RNN(input_size, hidden_size, num_layers, output_size)

# 定义输入数据
x = torch.randn(64, 10, 28)  # (batch_size, sequence_length, input_size)
h0 = torch.zeros(num_layers, x.size(0), hidden_size)  # 初始隐藏状态

# 前向传播
output = rnn(x, h0)
print(output.shape)  # 输出的形状为(batch_size, output_size)

上のコードでは、まずRNNクラスを定義し、それはnn.Moduleを継承し、コンストラクタでRNNモデルの各レイヤーを初期化します。次に、forwardメソッドでRNNの順伝播計算を実行し、予測結果として最後のタイムステップの出力を返します。最後に、モデルのインスタンスを作成し、入力データを定義し、順伝播計算を行います。

PyTorchには、nn.LSTMやnn.GRUなど、さまざまなタイプのリカレントニューラルネットワークモジュールがあります。開発者は必要に応じて適切なモジュールを選択して、自分自身のリカレントニューラルネットワークモデルを構築することができます。

bannerAds