metadata
library_name: transformers
license: apache-2.0
tags: []
pipeline_tag: audio-text-to-text

R1-AQA --- Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering

Introduction

R1-AQA is an audio question answering (AQA) model based on Qwen2-Audio-7B-Instruct, optimized through reinforcement learning using the group relative policy optimization (GRPO) algorithm. This implementation has achieved state-of-the-art performance on the MMAU Test-mini benchmark with only 38k post-training samples. For more details, please refer to our GitHub repository and Technical Report.
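The core of GRPO (introduced with DeepSeekMath) is a group-relative advantage: several responses are sampled for the same question, and each response's reward is standardized against the rest of its group, so no separate value (critic) model is needed. The sketch below illustrates that idea together with the PPO-style clipped objective GRPO optimizes. It assumes one scalar reward per sampled response, and the function names are hypothetical. It is not R1-AQA's training code, which lives in the authors' GitHub repository.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Standardize rewards within one group of responses sampled for the same prompt.

    The group mean acts as the baseline, replacing a learned value model.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; some implementations use the sample std
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_surrogate(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """PPO-style clipped objective for a single token (to be maximized).

    ratio = pi_theta(token) / pi_theta_old(token). GRPO also adds a KL penalty
    toward a reference model; that term is omitted here for brevity.
    """
    clipped_ratio = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    return min(ratio * advantage, clipped_ratio * advantage)


# Four sampled answers to one question; the first and last are correct (reward 1.0).
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print(advantages)
```

With two correct and two incorrect answers, the correct ones receive an advantage of roughly +1 and the incorrect ones roughly -1. If every response in a group earns the same reward, all advantages are zero and that group contributes no learning signal.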

Our main findings are as follows:

  • The GRPO algorithm can be applied directly and effectively to the audio modality, even to Qwen2-Audio-7B-Instruct, which has only 8.2B parameters.
  • With only 38k post-training samples, reinforcement learning outperforms supervised fine-tuning, indicating that RL-based approaches can be effective without large datasets.
  • An explicit reasoning process has not shown significant benefits for AQA tasks, and how to efficiently leverage deep thinking or step-by-step reasoning remains an open question for further research.
  • Large audio language models (LALMs) still lag far behind humans in auditory-language reasoning, suggesting that RL-based approaches warrant further exploration.

Additional Notes:

  • The AVQA training set originally consists of approximately 40k samples. However, we use only about 38k samples because some data sources are no longer available. Other datasets built from YouTube sources, such as AudioSet, face a similar issue. We believe that the missing 2k samples do not have a significant impact on the training results.
  • The statement about the 8.2B parameters is based on the Qwen2-Audio Technical Report.

Table: Accuracies (%) on the MMAU Test-mini benchmark

| Model | Method | Sound | Music | Speech | Average |
|---|---|---|---|---|---|
| Human* |  | 86.31 | 78.22 | 82.17 | 82.23 |
| Gemini Pro 2.0 Flash | Direct Inference* | 56.46 | 58.68 | 51.65 | 55.60 |
| Audio Flamingo 2 | Direct Inference* | 61.56 | 73.95 | 30.93 | 55.48 |
| GPT4o + Strong Cap. | Direct Inference* | 57.35 | 49.70 | 64.86 | 57.30 |
| Llama-3-8B-Instruct + Strong Cap. | Direct Inference* | 50.75 | 48.93 | 55.25 | 52.10 |
| Gemini Pro v1.5 | Direct Inference* | 56.75 | 49.40 | 58.55 | 54.90 |
| Qwen2-Audio-7B-Instruct | Direct Inference* | 54.95 | 50.98 | 42.04 | 49.20 |
| GPT4o + Weak Cap. | Direct Inference* | 39.33 | 41.90 | 58.25 | 45.70 |
| Llama-3-8B-Instruct + Weak Cap. | Direct Inference* | 34.23 | 38.02 | 54.05 | 42.10 |
| SALMONN | Direct Inference* | 41.00 | 34.80 | 25.50 | 33.70 |
| Qwen2-Audio-7B-Instruct | CoTA [1] | 60.06 | 64.30 | 60.70 | 61.71 |
| Qwen2-Audio-7B-Instruct | Zero-Shot-CoT [2] | 61.86 | 56.29 | 55.26 | 57.80 |
| Qwen2-Audio-7B-Instruct | GRPO (Ours) | 69.37 | 66.77 | 57.36 | 64.50 |

Notes

* The data are sourced from the MMAU official website: https://sakshi113.github.io/mmau_homepage/
[1] Xie, Zhifei, et al. "Audio-Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025).
[2] Ma, Ziyang, et al. "Audio-CoT: Exploring Chain-of-Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025).
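The headline comparisons in the table reduce to simple differences. The short script below copies three rows from the table and computes R1-AQA's per-category gain over direct inference with the same base model, along with its remaining gap to human performance. It is only a worked check of the published numbers. The Average column is used as reported rather than recomputed, since for some rows it is not exactly the unweighted mean of the three categories.

```python
from typing import Dict

# Accuracies (%) copied from three rows of the MMAU Test-mini table.
SCORES: Dict[str, Dict[str, float]] = {
    "Human": {"Sound": 86.31, "Music": 78.22, "Speech": 82.17, "Average": 82.23},
    "Qwen2-Audio-7B-Instruct, Direct Inference": {
        "Sound": 54.95, "Music": 50.98, "Speech": 42.04, "Average": 49.20},
    "Qwen2-Audio-7B-Instruct, GRPO (Ours)": {
        "Sound": 69.37, "Music": 66.77, "Speech": 57.36, "Average": 64.50},
}


def score_difference(a: str, b: str) -> Dict[str, float]:
    """Per-column accuracy of row a minus row b, in percentage points."""
    return {col: round(SCORES[a][col] - SCORES[b][col], 2) for col in SCORES[a]}


gain_over_base = score_difference(
    "Qwen2-Audio-7B-Instruct, GRPO (Ours)", "Qwen2-Audio-7B-Instruct, Direct Inference")
gap_to_human = score_difference("Human", "Qwen2-Audio-7B-Instruct, GRPO (Ours)")

print("Gain of GRPO over direct inference:", gain_over_base)
print("Remaining gap to human performance:", gap_to_human)
```

GRPO improves on the base model by roughly 14 to 16 points in every category, yet the average gap to human performance is still about 17.7 points. The largest gap is on Speech, at about 24.8 points.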

Inference

import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Load example audio
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)
if sampling_rate != 16000:
    waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]  # keep only the first channel, as a 1-D NumPy array

# Make prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt}
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Process
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(response)
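The prompt asks the model to put its final choice inside `<answer> </answer>` tags, so scoring a prediction comes down to extracting that span and comparing it with the gold option. The sketch below shows one way to do that. It assumes the model answers with the option text itself (e.g. `Man`) and that a case-insensitive exact match is acceptable. The helper names are hypothetical. The `format_reward` function illustrates the kind of rule-based reward a GRPO setup might use, but it is an assumption here, not the reward defined in the R1-AQA repository.

```python
import re
from typing import List, Optional

# Captures the content of an <answer> ... </answer> span, trimming surrounding whitespace.
ANSWER_PATTERN = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL)


def extract_answer(text: str) -> Optional[str]:
    """Return the content of the first <answer> tag, or None if there is none."""
    match = ANSWER_PATTERN.search(text)
    return match.group(1) if match else None


def accuracy_reward(text: str, gold: str) -> float:
    """1.0 if the extracted answer equals the gold option (case-insensitive), else 0.0."""
    pred = extract_answer(text)
    return float(pred is not None and pred.lower() == gold.strip().lower())


def format_reward(text: str) -> float:
    """Hypothetical rule: 1.0 if the response consists only of an <answer>...</answer> span."""
    return float(re.fullmatch(r"\s*<answer>.*?</answer>\s*", text, re.DOTALL) is not None)


def accuracy(responses: List[str], golds: List[str]) -> float:
    """Percentage of responses whose extracted answer matches the gold option."""
    if not golds or len(responses) != len(golds):
        raise ValueError("responses and golds must be non-empty and of equal length")
    hits = sum(accuracy_reward(r, g) for r, g in zip(responses, golds))
    return 100.0 * hits / len(golds)
```

Because `batch_decode` returns a list, the prediction for the example above is `extract_answer(response[0])`. Averaging `accuracy_reward` over a benchmark split gives an accuracy figure comparable in spirit to the table, although the official MMAU evaluation may normalize answers differently.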

Citation

@article{li2025reinforcement,
  title={Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering},
  author={Li, Gang and Liu, Jizhong and Dinkel, Heinrich and Niu, Yadong and Zhang, Junbo and Luan, Jian},
  journal={arXiv preprint arXiv:2503.11197},
  year={2025},
  url={https://github.com/xiaomi-research/r1-aqa; https://huggingface.co/mispeech/r1-aqa}
}