# Qwen2-Audio-7B / app.py
import gradio as gr
import spaces
import torch
import librosa
from threading import Thread
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
from transformers.generation import TextIteratorStreamer
# Model and Processor Loading (Qwen2-Audio is natively supported by transformers)
MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
model = Qwen2AudioForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
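# Optional: on smaller GPUs the model can instead be loaded in 4-bit. A minimal
# sketch, assuming bitsandbytes is installed (not used by this app as written):
#
# from transformers import BitsAndBytesConfig
# model = Qwen2AudioForConditionalGeneration.from_pretrained(
#     MODEL_ID,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
# )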
DESCRIPTION = "[Qwen2-Audio-7B-Instruct Demo](https://huggingface.co/Qwen/Qwen2-Audio-7B-Instruct)"
audio_extensions = (".wav", ".mp3", ".ogg", ".flac")
def process_audio(audio_path):
    """Load an audio file as a mono waveform at the model's expected sampling rate."""
    target_sr = processor.feature_extractor.sampling_rate  # 16 kHz for Qwen2-Audio
    audio_data, sample_rate = librosa.load(audio_path, sr=target_sr, mono=True)
    return audio_data, sample_rate
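# Example: process_audio("clip.mp3") -> (1-D float array, 16000). Resampling to the
# feature extractor's rate up front avoids a sampling-rate mismatch in the processor.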
@spaces.GPU
def qwen_inference(audio_input, text_input=None):
    if not isinstance(audio_input, str) or not audio_input.lower().endswith(audio_extensions):
        raise gr.Error("Please upload a valid audio file (WAV, MP3, OGG, or FLAC)")

    # Load the waveform; the processor needs the raw audio samples, not just the path
    audio_data, sample_rate = process_audio(audio_input)

    # Fall back to a generic prompt when no question is provided
    query = text_input if text_input else "Please describe what you hear in this audio clip."
    # Qwen2-Audio's chat template accepts mixed audio/text content entries
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio_url": audio_input},
                {"type": "text", "text": query},
            ],
        }
    ]
    # Render the conversation into the model's chat-prompt format
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
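    # For reference, the rendered prompt looks roughly like the sketch below
    # (exact special tokens depend on the template version shipped with the model):
    #   <|im_start|>user
    #   Audio 1: <|audio_bos|><|AUDIO|><|audio_eos|>
    #   {query}<|im_end|>
    #   <|im_start|>assistant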
    # The processor pairs the raw waveform with the audio placeholder in the text
    model_inputs = processor(
        text=text,
        audios=[audio_data],
        sampling_rate=sample_rate,
        return_tensors="pt",
    ).to(model.device, dtype=model.dtype)  # cast float features to the fp16 model
    # Stream decoded tokens as they are generated instead of waiting for the full reply
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.7,
        do_sample=True,
    )
    # Run generation in a background thread so this function can yield as text arrives
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate chunks and re-yield the growing transcript for Gradio to display
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
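# A minimal non-streaming wrapper for local sanity checks (hypothetical helper, not
# wired into the UI). Each value yielded by qwen_inference is the full buffer so far,
# so the last one is the complete answer.
def run_once(audio_path, question=None):
    result = ""
    for partial in qwen_inference(audio_path, question):
        result = partial
    return result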
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Audio Input"):
        with gr.Row():
            with gr.Column():
                input_audio = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                )
                text_input = gr.Textbox(
                    label="Question (optional)",
                    placeholder="Ask a question about the audio or leave empty for a general description",
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id="output" picks up the #output CSS rule defined above
                output_text = gr.Textbox(label="Output Text", elem_id="output")

    submit_btn.click(
        qwen_inference,
        inputs=[input_audio, text_input],
        outputs=[output_text],
    )
# queue() keeps streamed (generator) outputs flowing to the client incrementally
demo.queue().launch(debug=True)