Update README.md
README.md
CHANGED
@@ -1,6 +1,9 @@
# Twister

-

---
## Performance
@@ -9,9 +12,9 @@
| Dataset\Model | WLV2-Oracle ↓ | WLV2-Auto ↓ | WLV3-Auto ↓ | COOL-Whisper ↓ | Twister (Ours) ↓ |
|---------------------------|---------------|-------------|-------------|----------------|------------------|
| ASCEND-OVERALL* | 21.14 (AUTO) | 21.14 | 23.22 | 19.71 | **17.74** (-16.08%) |
-
-
-
| CommonVoice16-zh-TW | 9.02 (ZH) | 9.84 | 8.95 | 11.86 | **7.97** (-19%) |
| CSZS-zh-en* | 29.49 (AUTO) | 29.49 | 26.43 | 20.90 | **13.01** (-55.88%) |
@@ -44,42 +47,7 @@ Twister is fine-tuned on about **4,000 hours** of high-quality speech synthesize

## 🔧 Usage Example

-
-from transformers import WhisperProcessor, WhisperForConditionalGeneration
-from datasets import load_dataset
-import torch
-
-# 1. Load model and processor
-processor = WhisperProcessor.from_pretrained("Mediatek-Research/Twister")
-model = WhisperForConditionalGeneration.from_pretrained("Mediatek-Research/Twister")
-model.eval()
-
-# 2. Set decoding prompt for Chinese transcription
-forced_decoder_ids = processor.get_decoder_prompt_ids(language="zh", task="transcribe")
-
-# 3. Load a sample dataset
-ds = load_dataset("ky552/ML2021_ASR_ST", split="test")
-sample = ds[0]["audio"]
-
-# 4. Preprocess input audio
-inputs = processor(
-    sample["array"],
-    sampling_rate=sample["sampling_rate"],
-    return_tensors="pt"
-)
-input_features = inputs.input_features
-
-# 5. Inference
-with torch.no_grad():
-    predicted_ids = model.generate(
-        input_features,
-        forced_decoder_ids=forced_decoder_ids
-    )
-
-# 6. Decode prediction
-transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-print("Transcription:", transcription)
-```

```python
import torchaudio
@@ -87,7 +55,7 @@ import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline

# 1. Load audio
-audio_path = "./
waveform, sample_rate = torchaudio.load(audio_path)

# 2. Preprocess
@@ -116,14 +84,64 @@ asr_pipeline = AutomaticSpeechRecognitionPipeline(
output = asr_pipeline(waveform)
print("Result:", output["text"])

# Whisper: 使用這個方式的時候(incorrect)
# Twister: 使用這個 function 的時候(correct)
```
---

## 📜 Citation

-

-> *Equal contribution
-> [Paper](https://youtu.be/dQw4w9WgXcQ?si=p5zCM8Hys4FdEOK0)
# Twister

**Twister** 是一個針對繁體中文以及中英交錯情境進行優化的語音辨識模型。**Twister** 基於 Whisper-large-v2 之上訓練而成,其中文部分完全採用合成語音資料進行訓練。

**Twister** is an advanced ASR model fine-tuned from [Whisper-large-v2](https://github.com/openai/whisper) with TTS-synthesized data, specifically optimized for Taiwanese Mandarin and Mandarin-English code-switching scenarios.

---
## Performance

| Dataset\Model | WLV2-Oracle ↓ | WLV2-Auto ↓ | WLV3-Auto ↓ | COOL-Whisper ↓ | Twister (Ours) ↓ |
|---------------------------|---------------|-------------|-------------|----------------|------------------|
| ASCEND-OVERALL* | 21.14 (AUTO) | 21.14 | 23.22 | 19.71 | **17.74** (-16.08%) |
| - ASCEND-EN | 27.20 (EN) | 27.36 | 27.21 | 29.39 | **26.64** (-2.63%) |
| - ASCEND-ZH | **13.75** (ZH)| 17.49 | 17.41 | 18.90 | 16.04 (-8.29%) |
| - ASCEND-MIX* | 21.01 (AUTO) | 21.01 | 25.13 | 17.34 | **16.38** (-22.01%) |
| CommonVoice16-zh-TW | 9.02 (ZH) | 9.84 | 8.95 | 11.86 | **7.97** (-19%) |
| CSZS-zh-en* | 29.49 (AUTO) | 29.49 | 26.43 | 20.90 | **13.01** (-55.88%) |
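
Lower is better in every column (↓). The percentages in parentheses appear to be Twister's relative error-rate reduction with respect to the WLV2-Auto column; for example, on ASCEND-OVERALL: (21.14 − 17.74) / 21.14 ≈ 16.08%.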
## 🔧 Usage Example

To run the model on `input_audio.wav`:
```python
import torchaudio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline

# 1. Load audio
audio_path = "./input_audio.wav"
waveform, sample_rate = torchaudio.load(audio_path)

# 2. Preprocess
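# (The preprocessing and pipeline-construction lines that belong here are collapsed
#  in this diff view. What follows is only a sketch of the omitted steps, assuming
#  the standard transformers setup with the "Mediatek-Research/Twister" checkpoint.)
processor = WhisperProcessor.from_pretrained("Mediatek-Research/Twister")
model = WhisperForConditionalGeneration.from_pretrained("Mediatek-Research/Twister")
model.eval()

# Whisper expects 16 kHz mono input
if sample_rate != 16000:
    waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
waveform = waveform.mean(dim=0).numpy()  # collapse channels to a 1-D numpy array

# Build the ASR pipeline from the loaded model and processor
asr_pipeline = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)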

output = asr_pipeline(waveform)
print("Result:", output["text"])

```
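
By default the pipeline lets Whisper pick the language automatically. To pin decoding to Mandarin transcription (the example removed in this revision did so with `forced_decoder_ids`), one option is to pass generation options through the pipeline call; the snippet below is a sketch based on the generic transformers Whisper generation API (`generate_kwargs`, `language`, `task`), not on code from this README:

```python
# Sketch (assumed API, not from this README): force Mandarin transcription
# instead of automatic language detection.
output = asr_pipeline(waveform, generate_kwargs={"language": "zh", "task": "transcribe"})
print("Result:", output["text"])
```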

You can obtain a WAV file for testing by loading one from a benchmark dataset:

```python
from datasets import load_dataset
import torch
import torchaudio

ds = load_dataset("ky552/ML2021_ASR_ST", split="test")
sample = ds[0]["audio"]

audio_array = sample["array"]
sampling_rate = sample["sampling_rate"]

# Add a channel dimension and use float32 for broad torchaudio backend support
waveform = torch.tensor(audio_array, dtype=torch.float32).unsqueeze(0)

torchaudio.save("input_audio.wav", waveform, sampling_rate)

# Decoding Results:
# Whisper: 使用這個方式的時候(incorrect)
# Twister: 使用這個 function 的時候(correct)
```
---
## Training Data

Twister 的訓練採樣自以下數據集:

The training data of Twister is sampled from the following publicly available sources:

| Dataset Name | Type | Language | Total Hours | License |
|------------------------------------------------------------------------------|--------|-----------------|-------------|---------|
| ODC Synth | Synth. | Mandarin | 10,000 | Open Data Commons License Attribution + Apache2.0* |
| [CommonVoice17-EN](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0) | Real | English | 1,738 | Creative Commons Zero |
| [NTUML2021](https://huggingface.co/datasets/ky552/ML2021_ASR_ST) | Real | Code-switching | 11 | MIT License |
*ODC Synth is generated using text from [FineWeb2](https://huggingface.co/datasets/HuggingFaceFW/fineweb-2) (ODC License) and the TTS model [BreezyVoice](https://huggingface.co/MediaTek-Research/BreezyVoice) (Apache2.0 License).
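
Note that the hours listed above are the sizes of the source corpora (roughly 11,700 hours in total), not the final training mixture; per the description earlier in this README, only about 4,000 hours of the synthesized speech are sampled for fine-tuning.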

---

## 📜 Citation

If you find this model useful, please cite our work:

**Cheng-Kang Chou\***, **Chan-Jan Hsu\***, Ho-Lam Chung, Liang-Hsuan Tseng, Hsi-Chun Cheng, Yu-Kuan Fu, Kuan-Po Huang, Hung-yi Lee
[*A Self-Refining Framework for Enhancing ASR Using TTS-Synthesized Data*](https://arxiv.org/pdf/2506.11130)
\*Equal contribution
```bibtex
@article{chou2025selfrefiningframeworkenhancingasr,
  title={A Self-Refining Framework for Enhancing ASR Using TTS-Synthesized Data},
  author={Cheng Kang Chou and Chan-Jan Hsu and Ho-Lam Chung and Liang-Hsuan Tseng and Hsi-Chun Cheng and Yu-Kuan Fu and Kuan Po Huang and Hung-Yi Lee},
  journal={arXiv preprint arXiv:2506.11130},
  year={2025}
}
```