---
library_name: transformers
license: apache-2.0
tags: []
pipeline_tag: audio-text-to-text
---

# R1-AQA --- Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering


## Introduction

R1-AQA is an audio question answering (AQA) model based on `Qwen2-Audio-7B-Instruct`, optimized through reinforcement learning with the group relative policy optimization (GRPO) algorithm. 
This implementation achieves state-of-the-art performance on the MMAU *Test-mini* benchmark with only 38k post-training samples.
For more details, please refer to our [GitHub](https://github.com/xiaomi-research/r1-aqa) repository and [Technical Report](https://arxiv.org/abs/2503.11197).

Our main findings are as follows:

- The GRPO algorithm can be directly and effectively applied to the audio modality, even to `Qwen2-Audio-7B-Instruct` with only 8.2B parameters.
- With only 38k post-training samples, reinforcement learning outperforms supervised fine-tuning, indicating that RL-based approaches can be effective without large datasets.
- The explicit reasoning process has not shown significant benefits for AQA tasks, and how to efficiently leverage *deep thinking* or step-by-step reasoning remains an open question for further research.
- Large audio language models (LALMs) still lag far behind humans in auditory-language reasoning, suggesting that RL-based approaches warrant further exploration.

Additional Notes:  
The AVQA training set originally consists of approximately 40k samples, but we use only about 38k because some data sources have become invalid. Other datasets that rely on YouTube sources, such as AudioSet, face the same issue. We believe the missing ~2k samples do not have a significant impact on the training results.
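
The prompt format used in the inference example below asks the model to wrap its choice in `<answer> </answer>` tags, which makes a simple rule-based accuracy reward possible. The snippet below is a minimal, illustrative sketch of such a reward together with a GRPO-style group-normalized advantage; the function names (`accuracy_reward`, `group_relative_advantages`) and the exact matching rule are assumptions for illustration, not the implementation used for training (see the GitHub repository for the reference code).

```python
import re

import torch


def accuracy_reward(completion: str, ground_truth: str) -> float:
    """Return 1.0 if the text inside the last <answer> ... </answer> tag matches
    the ground-truth option (case-insensitive), else 0.0. Illustrative sketch only."""
    matches = re.findall(r"<answer>(.*?)</answer>", completion, flags=re.DOTALL)
    if not matches:
        return 0.0
    return 1.0 if matches[-1].strip().lower() == ground_truth.strip().lower() else 0.0


def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """GRPO-style advantages: normalize each completion's reward by the mean and
    standard deviation of its group (the completions sampled for the same prompt)."""
    return (rewards - rewards.mean()) / (rewards.std() + eps)


# Example: four completions sampled for the same question, ground truth "Woman"
completions = [
    "The voice is high-pitched. <answer>Woman</answer>",
    "<answer>Man</answer>",
    "It sounds synthetic. <answer>Robot</answer>",
    "<answer>Woman</answer>",
]
rewards = torch.tensor([accuracy_reward(c, "Woman") for c in completions])
print(group_relative_advantages(rewards))  # correct answers receive positive advantages
```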


### Table: Accuracies (%) on MMAU Test-mini benchmark
| Model                                      | Method                  | Sound  | Music  | Speech | Average |
|--------------------------------------------|-------------------------|--------|--------|--------|---------|
| \                                          | Human\*                 | 86.31  | 78.22  | 82.17  | 82.23   |
| Gemini Pro 2.0 Flash                       | Direct Inference\*      | 56.46  | 58.68  | 51.65  | 55.60   |
| Audio Flamingo 2                           | Direct Inference\*      | 61.56  | **73.95** | 30.93  | 55.48   |
| GPT4o + Strong Cap.                        | Direct Inference\*      | 57.35  | 49.70  | **64.86** | 57.30   |
| Llama-3-8B-Instruct + Strong Cap.          | Direct Inference\*      | 50.75  | 48.93  | 55.25  | 52.10   |
| Gemini Pro v1.5                            | Direct Inference\*      | 56.75  | 49.40  | 58.55  | 54.90   |
| Qwen2-Audio-7B-Instruct                    | Direct Inference\*      | 54.95  | 50.98  | 42.04  | 49.20   |
| GPT4o + Weak Cap.                          | Direct Inference\*      | 39.33  | 41.90  | 58.25  | 45.70   |
| Llama-3-8B-Instruct + Weak Cap.            | Direct Inference\*      | 34.23  | 38.02  | 54.05  | 42.10   |
| SALMONN                                    | Direct Inference\*      | 41.00  | 34.80  | 25.50  | 33.70   |
| Qwen2-Audio-7B-Instruct                    | CoTA \[1\]            | 60.06  | 64.30  | 60.70  | 61.71   |
| Qwen2-Audio-7B-Instruct                    | Zero-Shot-CoT \[2\]   | 61.86  | 56.29  | 55.26  | 57.80   |
| **Qwen2-Audio-7B-Instruct**                | **GRPO (Ours)**         | **69.37** | 66.77  | 57.36  | **64.50** |

#### Notes:
\* The data are sourced from the MMAU official website: [https://sakshi113.github.io/mmau_homepage/](https://sakshi113.github.io/mmau_homepage/)  
\[1\] Xie, Zhifei, et al. "Audio-Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025).  
\[2\] Ma, Ziyang, et al. "Audio-CoT: Exploring Chain-of-Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025).  



## Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Load example audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)
if sampling_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]

# Make prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt}
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Preprocess text + audio, run generation, and decode only the newly generated tokens
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]  # drop the prompt tokens
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(response)
```
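
Since the prompt asks the model to place its choice inside `<answer> </answer>` tags, the final option can be recovered with a small post-processing step. The snippet below is a minimal example that continues from the `response` variable above; the regex-based extraction is illustrative and assumes the model followed the requested format.

```python
import re

# `response` is the list returned by processor.batch_decode above.
match = re.search(r"<answer>(.*?)</answer>", response[0], flags=re.DOTALL)
final_answer = match.group(1).strip() if match else response[0].strip()
print(final_answer)  # expected to be one of ["Man", "Woman", "Child", "Robot"]
```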