---
library_name: transformers
license: apache-2.0
tags: []
pipeline_tag: audio-text-to-text
---
# R1-AQA --- Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering
## Introduction
R1-AQA is an audio question answering (AQA) model based on `Qwen2-Audio-7B-Instruct`, optimized through reinforcement learning using the group relative policy optimization (GRPO) algorithm.
This implementation achieves state-of-the-art performance on the MMAU *Test-mini* benchmark with only 38k post-training samples.
For more details, please refer to our [Github](https://github.com/xiaomi-research/r1-aqa) and [Technical Report](https://arxiv.org/abs/2503.11197).

Our main findings are as follows:
- The GRPO algorithm can be directly and effectively applied to the audio modality, even to `Qwen2-Audio-7B-Instruct` with only 8.2B parameters.
- With only 38k post-training samples, reinforcement learning outperforms supervised fine-tuning, indicating that RL-based approaches can be effective without large datasets.
- The explicit reasoning process has not shown significant benefits for AQA tasks, and how to efficiently leverage *deep thinking* or step-by-step reasoning remains an open question for further research.
- Large audio language models (LALMs) still lag far behind humans in auditory-language reasoning, suggesting that RL-based approaches warrant further exploration.

Additional Notes:

The AVQA training set originally consists of approximately 40k samples. However, we use only about 38k samples because some data sources have become invalid. Other datasets built from YouTube sources, such as AudioSet, face a similar issue. We believe that the missing 2k samples do not have a significant impact on the training results.
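For intuition on the GRPO update mentioned above: instead of fitting a value model, GRPO samples a group of responses per question and normalizes each response's reward by the group statistics. The snippet below is a minimal, illustrative sketch of just that advantage computation; the function name, the 0/1 rule-based rewards, and the epsilon term are assumptions for illustration, not code from this repository.

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """Group-relative advantages for one question.

    `rewards` holds one scalar reward per sampled response to the same
    audio/question pair. Normalizing by the group mean and standard
    deviation removes the need for a learned value model.
    """
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Illustrative example: 4 sampled answers, reward 1.0 when the text inside
# <answer> </answer> matches the ground-truth option, 0.0 otherwise.
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0])
print(grpo_advantages(rewards))  # correct answers receive positive advantages
```

In the full algorithm these advantages weight a clipped policy-gradient objective with a KL penalty toward the reference model; see the Technical Report for the exact formulation.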
### Table: Accuracies (%) on MMAU Test-mini benchmark
| Model | Method | Sound | Music | Speech | Average |
|--------------------------------------------|-------------------------|--------|--------|--------|---------|
| \ | Human\* | 86.31 | 78.22 | 82.17 | 82.23 |
| Gemini Pro 2.0 Flash | Direct Inference\* | 56.46 | 58.68 | 51.65 | 55.60 |
| Audio Flamingo 2 | Direct Inference\* | 61.56 | **73.95** | 30.93 | 55.48 |
| GPT4o + Strong Cap. | Direct Inference\* | 57.35 | 49.70 | **64.86** | 57.30 |
| Llama-3-8B-Instruct + Strong Cap. | Direct Inference\* | 50.75 | 48.93 | 55.25 | 52.10 |
| Gemini Pro v1.5 | Direct Inference\* | 56.75 | 49.40 | 58.55 | 54.90 |
| Qwen2-Audio-7B-Instruct | Direct Inference\* | 54.95 | 50.98 | 42.04 | 49.20 |
| GPT4o + Weak Cap. | Direct Inference\* | 39.33 | 41.90 | 58.25 | 45.70 |
| Llama-3-8B-Instruct + Weak Cap. | Direct Inference\* | 34.23 | 38.02 | 54.05 | 42.10 |
| SALMONN | Direct Inference\* | 41.00 | 34.80 | 25.50 | 33.70 |
| Qwen2-Audio-7B-Instruct | CoTA \[1\] | 60.06 | 64.30 | 60.70 | 61.71 |
| Qwen2-Audio-7B-Instruct | Zero-Shot-CoT \[2\] | 61.86 | 56.29 | 55.26 | 57.80 |
| **Qwen2-Audio-7B-Instruct** | **GRPO (Ours)** | **69.37** | 66.77 | 57.36 | **64.50** |
#### Notes:
\* The data are sourced from the MMAU official website: [https://sakshi113.github.io/mmau_homepage/](https://sakshi113.github.io/mmau_homepage/)

\[1\] Xie, Zhifei, et al. "Audio-Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025).

\[2\] Ma, Ziyang, et al. "Audio-CoT: Exploring Chain-of-Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025).
## Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Load example audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav" # from MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)
if sampling_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]

# Make prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt}
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Process
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response)
```
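The prompt asks the model to wrap its choice in `<answer> </answer>` tags, so the decoded response usually needs a small post-processing step to recover the selected option. The helper below is a minimal sketch under that assumption; the regex and the fallback to the raw text are illustrative choices, not part of the released code.

```python
import re

def extract_answer(response: str) -> str:
    """Return the text inside the first <answer> ... </answer> pair.

    Falls back to the stripped raw response if the tags are missing.
    """
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", response, re.DOTALL)
    return match.group(1).strip() if match else response.strip()

# Applied to the decoded output above, e.g. extract_answer(response[0]).
print(extract_answer("Based on the audio, <answer>Man</answer>"))  # -> Man
```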