rnnt-trf-v1 (kairo-ai)

ローマ字入力列から日本語文(漢字かな交じり文)への変換を行う RNN-Transducer モデルです。 タイプミス(重複・脱字・隣接キー誤打・転置)を含む入力に対しても頑健に変換できるよう、ノイズを付与したデータで学習しています。

A RNN-Transducer model that converts romaji input sequences into Japanese text (kanji-kana mixed sentences), trained to be robust against typos.

モデル概要

  • タスク: ローマ字 → 日本語文変換(例: wagahaihanekodearu.吾輩は猫である。
  • アーキテクチャ: RNN-T(エンコーダ・予測ネットワークともに Transformer)
    • エンコーダ: Transformer 4層
    • 予測ネットワーク (predictor): Transformer 2層
    • 埋め込み次元 / 隠れ次元: 256 / 256
    • アテンションヘッド数: 4、dropout: 0.1
    • 最大位置数: 512、最大系列長: 256
    • パラメータ数: 約 7.6M(7,567,509)
    • 語彙: 入力 494 トークン(文字単位)、出力 4,501 トークン(文字単位、最低出現頻度 2)
  • 損失関数: RNN-T loss
  • 最適化: Adam(ピーク学習率 3e-4、cosine スケジューラ、warmup 比率 0.05、weight decay 0.01、gradient clip 1.0)
  • バッチサイズ: 8
  • 学習環境: NVIDIA RTX 6000 Ada(48GB)、20 エポックで約 10 時間
  • 公開チェックポイント: epoch 20 時点(epoch_020.pt
    • 学習は 50 エポックの設定で開始したが、計算環境の都合により epoch 20 で中断。中断時点でも train / valid loss・CER は改善傾向にあり、継続学習によるさらなる精度向上の余地がある。
  • ソースコード: github.com/takumiecd/kairo-ai(MIT License)

評価結果

テストセット最終評価(未使用データ 30,144 ペア、greedy デコード)

指標
corpus CER 7.05%
文単位完全一致率 45.4%

ノイズ種別ごとの CER:

ノイズ種別 件数 CER
none(ノイズなし) 14,910 5.99%
duplication(文字重複) 3,628 7.02%
romaji_variant(表記ゆれ) 1,100 7.15%
keyboard_typo(隣接キー誤打) 3,555 8.45%
swap(隣接文字転置) 3,400 8.54%
deletion(脱字) 3,551 8.65%

学習中の推移

Epoch train loss valid loss valid CER
5 9.91 8.26 10.48%
10 6.63 6.01 7.64%
15 5.22 5.16 6.59%
20 4.22 4.47 5.87%

学習中の valid CER は検証データ 100 サンプルの抽出値(5 エポックごと、greedy)。テスト全件 CER(7.05%)との差は少数サンプルによるサンプリング誤差であり、最終性能はテストセット全件評価を参照のこと。

学習データ

301,429 ペア(input: ローマ字列, target: 日本語文)。約半数にタイプミスを模したノイズを付与。

ソース構成

ソース ペア数 割合 ライセンス
日本語 Wikipedia(記事ダンプ) 199,929 66.3% CC BY-SA / GFDL
Tatoeba(日本語例文) 100,000 33.2% CC BY 2.0 FR
青空文庫(夏目漱石『吾輩は猫である』) 1,500 0.5% パブリックドメイン(確認済み)

ノイズ構成

ノイズ種別 ペア数 割合
none(ノイズなし) 150,500 49.9%
duplication(文字重複) 35,546 11.8%
deletion(脱字) 35,518 11.8%
keyboard_typo(隣接キー誤打) 35,313 11.7%
swap(隣接文字転置) 33,646 11.2%
romaji_variant(ローマ字表記ゆれ) 10,906 3.6%

ファイル構成

train.rnnt.train が出力する artifact ディレクトリと同じ構成です。hf download でそのまま --artifact-dir に使えます。

ファイル 内容
checkpoints/best.pt 検証損失が最小のチェックポイント(= epoch 20。decode のデフォルト)
checkpoints/epoch_020.pt epoch 20 のチェックポイント(model / optimizer / scheduler の state と学習 config を含む)
config.json モデル・学習設定(チェックポイント内 config と同一)
input_vocab.json 入力(ローマ字)語彙 494 トークン
output_vocab.json 出力(日本語文字)語彙 4,501 トークン
metrics.jsonl エポックごとの train/valid loss・CER・学習率
loss_curve.png / cer_curve.png 学習曲線

使い方

本モデルは独自アーキテクチャのため、transformersfrom_pretrained には対応していません。 kairo-ai のコードと併せて使用します。

git clone https://github.com/takumiecd/kairo-ai
cd kairo-ai && uv sync

# モデル一式(checkpoint + config + vocab)をダウンロード
hf download takumiecd/kairo-rnnt-trf-v1 --local-dir artifacts/rnnt-trf-v1

# greedy デコードで推論(デフォルトで checkpoints/best.pt が使われる)
uv run python -m decode.greedy \
  --artifact-dir artifacts/rnnt-trf-v1 \
  --input "wagahaihanekodearu."
# => 吾輩は猫である。

再現手順

本モデルの構築に使用したコマンド。詳細なフラグの説明は kairo-ai リポジトリ の README を参照。

1. データセット構築

# Wikipedia(10万unit × augmentations 1 ≒ 20万ペア)
python -m dataset.source_wikipedia \
  --dump data/raw/wiki/jawiki-latest-pages-articles.xml.bz2 \
  --output data/external/wiki_ja.jsonl \
  --license cc_by_sa_gfdl \
  --max-units 100000 \
  --augmentations 1

# Tatoeba(5万unit × augmentations 1 ≒ 10万ペア)
python -m dataset.source_tatoeba \
  --sentences data/raw/tatoeba/sentences.tar.bz2 \
  --output data/external/tatoeba_ja.jsonl \
  --lang jpn \
  --max-units 50000 \
  --augmentations 1

# 青空文庫『吾輩は猫である』(500unit × augmentations 2 = 1,500ペア)
python -m dataset.source_text \
  --source https://www.aozora.gr.jp/cards/000148/files/789_ruby_5639.zip \
  --output data/external/aozora_wagahai.jsonl \
  --source-name aozora \
  --license aozora_public_domain_checked \
  --format aozora \
  --max-units 500 \
  --augmentations 2

2. 結合と分割(train:valid:test = 8:1:1, seed 0 → 241,143 / 30,142 / 30,144 ペア)

cat data/external/wiki_ja.jsonl \
    data/external/tatoeba_ja.jsonl \
    data/external/aozora_wagahai.jsonl > data/combined/all_sources.jsonl

python -m dataset.split \
  --input data/combined/all_sources.jsonl \
  --output-dir data/combined/all_sources

3. 学習(チェックポイント内 config より)

python -m train.rnnt.train \
  --data data/combined/all_sources/train.jsonl \
  --valid-data data/combined/all_sources/valid.jsonl \
  --output-dir artifacts/rnnt-trf-v1 \
  --encoder-type transformer --prediction-type transformer \
  --encoder-layers 4 --prediction-layers 2 \
  --embed-dim 256 --hidden-dim 256 --num-heads 4 \
  --epochs 50 --batch-size 8 \
  --learning-rate 3e-4 \
  --lr-scheduler cosine --warmup-ratio 0.05 \
  --device cuda \
  --valid-decode greedy --valid-cer-samples 100 --valid-cer-every 5 \
  --max-positions 512

50 エポック設定だが、計算環境の都合により epoch 20 で中断(前述)。

4. 評価

python -m eval.run_test \
  --artifact-dir artifacts/rnnt-trf-v1 \
  --data data/combined/all_sources/test.jsonl \
  --device cuda

推論例(epoch 20, greedy, CPU で動作確認済み)

入力(ローマ字) 出力
wagahaihanekodearu. 吾輩は猫である。
wagaahaihanekodearu.(重複ノイズ) 吾輩は猫である。
namaehamdaanai.(転置ノイズ) 名前はまだない。
kyouhayoitenkidesune. 今日はよい天気ですね。
kikaigakushuunorepo-towokakimasu. 機械学習のレポートを書きます。

ライセンスと帰属

  • 学習データに CC BY-SA の Wikipedia 由来テキストを含むため、本モデルは CC BY-SA 4.0 で公開しています。
  • Tatoeba の例文は tatoeba.org より CC BY 2.0 FR で提供されています。
  • 青空文庫のテキストは 青空文庫 より取得したパブリックドメイン作品です。

出典

Downloads last month
-
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support