PyTorchでTransformerを呼び出す方法は何ですか?

PyTorchでは、torch.nn.Transformerクラスを使用して、Transformerモデルを呼び出すことができます。以下はTransformerモデルを使用したサンプルコードです。

import torch
import torch.nn as nn

# 定义Transformer模型
class TransformerModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads):
        super(TransformerModel, self).__init__()
        
        self.transformer = nn.Transformer(
            d_model=input_dim,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim
        )
        
    def forward(self, src, tgt):
        output = self.transformer(src, tgt)
        return output

# 创建Transformer模型实例
input_dim = 512
hidden_dim = 2048
num_layers = 6
num_heads = 8
model = TransformerModel(input_dim, hidden_dim, num_layers, num_heads)

# 准备输入数据
batch_size = 16
src_seq_len = 10
tgt_seq_len = 5
src = torch.randn(batch_size, src_seq_len, input_dim)
tgt = torch.randn(batch_size, tgt_seq_len, input_dim)

# 前向传播
output = model(src, tgt)

この例では、まず、nn.Moduleを継承したカスタムTransformerモデルクラスTransformerModelを定義します。__init__メソッドでは、nn.Transformerクラスを使用してTransformerモデルを作成し、入力次元、隠れ層の次元、エンコーダーおよびデコーダーのレイヤー数、およびアテンションのヘッド数を指定します。forwardメソッドでは、入力データをTransformerモデルに渡して順伝播させ、出力を返します。

その後、Transformerモデルのインスタンスを作成し、入力データを準備しました。最終的に、モデルのforwardメソッドを呼び出して前向き伝播を行い、出力結果を得ました。

bannerAds