前回の記事はこちらです。
Rust用に変換したGPT-2の重みを使って、実際にTextGenerationタスクをしてみました。

※使用環境などは前回と同じです。

プロジェクトの作成

Rustではcargoコマンドでプロジェクトの作成や実行、公開などができるらしく、すごく便利だなぁと感動しました。
プロジェクトを作成します!

cargo new gpt2_test

すると、以下の構成になっていると思います

workspace
    ├── gpt2/
    ├── rust-bert/
    └── gpt2_test/
            ├──Cargo.toml
            └──src/
                └──main.rs

Cargo.tomlの編集

Cargo.tomlにはプロジェクトの情報や、依存関係を記述します。
今回は2つのクレートを使っているので、[dependencies]の下に追記します。

rust-bert = "0.17.0"
anyhow = "1.0.55"

※ 各クレートのバージョンやドキュメントはPythonでいうPyPIの、crates.ioというところから確認できます。

main.rsの編集

15行目あたりの、<モデルのパス>を、前回Rust用に変換したモデルの場所に置き換えてくだい。
(Rust始めて1ヶ月もたたない初心者なので、文法おかしいところがあると思いますが、動くので許してください…)

use std::path::PathBuf;
use rust_bert::resources::LocalResource;
use rust_bert::resources::Resource::Local;
use rust_bert::pipelines::text_generation::{ TextGenerationConfig, TextGenerationModel };
use rust_bert::resources::Resource;

fn input() -> String {
    let mut text = String::new();
    println!("Input: ");
    std::io::stdin().read_line(&mut text).unwrap();
    return text.trim().to_string();
}

fn get_resource(item: String) -> Resource {
    let mut model_dir = PathBuf::from("<モデルのパス>/gpt2/");
    model_dir.push(&item);
    println!("{:?}", model_dir);
    let resource = Local(LocalResource{
        local_path: model_dir,
    });
    return resource;
}

fn main() -> anyhow::Result<()> {

    let model_resource = get_resource(String::from("rust_model.ot"));
    let vocab_resource = get_resource(String::from("vocab.json")); 
    let config_resource = get_resource(String::from("config.json"));
    let merges_resource = get_resource(String::from("merges.txt"));

    // configの作成
    let generate_config = TextGenerationConfig {     
        model_type: rust_bert::pipelines::common::ModelType::GPT2,
        model_resource,
        config_resource,
        vocab_resource,
        merges_resource,

        // パラメーター調整
        repetition_penalty: 1.6,
        max_length: 30,
        do_sample: false,
        num_beams: 1,
        temperature: 1.0,
        ..Default::default()
    };
    
    // 上のconfigからモデル作成
    let model = TextGenerationModel::new(generate_config)?;
    //model.set_device(Device::cuda_if_available());
    loop {
        let input_text = input();
        // QUITで終了できるように
        if input_text == "QUIT" { break; }
        // 時間測定スタート
        let start = std::time::Instant::now();
        println!("Generating...");
        // 推論
        let output = model.generate(&[input_text], None);

        for sentence in output {
            println!("「{:?}」", sentence);
        }
        // 時間測定。差分を取る
        let stop = std::time::Instant::now();
        println!("<Time: {:.3}s>", (stop.duration_since(start).as_millis() as f64) / 1000.0);
        println!("\n");
    }
    Ok(())
}

実行!!

gpt2_test下で実行します。

cargo run

cargo runは、『コードをコンパイルして実行する』というお得なコマンドらしいです。すごい。
※コンパイルだけしたい場合は、cargo buildでいけます。

初回はコンパイルと、外部クレートのダウンロードがあるため、少々時間がかかります。
次回以降は、target/debug/に実行ファイルが生成されているので、

./target/debug/gpt2_test

でサクッと実行できます!

実行結果

Screen Shot 2022-03-07 at 19.08.20.png

推論時間:0.495秒

感想

rust-bertについて、GPT-2によるTextGenerationタスクを試してみました。
他にもQuestion AnsweringやTranslationも用意されているようです!
rust-bert 公式ドキュメント

日本語モデルでの生成はまだ上手く行っていないので、できたらまた記事にしようと思います。
環境構築も(libtorchを除けば)簡単だと思うので、ぜひ興味もって試してくれたら嬉しいです。

URLなど

crates.io
rust_bert
Rust 日本語ドキュメント

pythonで学習したDNNモデルをC++から利用する(PyTorch & libtorch版)
※libtorchのインストールで参考にさせていただきました。ありがとうございましたmm

bannerAds