import gradio as gr
import spaces
import torch
import soundfile as sf
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Model and tokenizer loading
MODEL_ID = "Qwen/Qwen-Audio-Chat"
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

DESCRIPTION = "[Qwen-Audio-Chat Demo](https://huggingface.co/Qwen/Qwen-Audio-Chat)"

AUDIO_EXTENSIONS = (".wav", ".mp3", ".ogg", ".flac")
def process_audio(audio_path):
    """Load an audio file, returning a mono sample array and its sample rate."""
    audio_data, sample_rate = sf.read(audio_path)
    if audio_data.ndim > 1:  # downmix stereo to mono
        audio_data = audio_data.mean(axis=1)
    return audio_data, sample_rate
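# Note: sf.read returns samples as float64 at the file's native sample rate.
# No resampling is done in this demo; if the checkpoint's preprocessing
# expects a fixed rate (16 kHz is common for audio LMs), resample before use.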
@spaces.GPU
def qwen_inference(audio_input, text_input=None):
    if not isinstance(audio_input, str) or not audio_input.lower().endswith(AUDIO_EXTENSIONS):
        raise gr.Error("Please upload a valid audio file (WAV, MP3, OGG, or FLAC)")

    # Load the audio to validate it; the model is handed the file path below
    audio_data, sample_rate = process_audio(audio_input)

    # Fall back to a generic prompt when no question is given
    query = text_input or "Please describe what you hear in this audio clip."

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio", "audio": audio_input},
                {"type": "text", "text": query},
            ],
        }
    ]
    # Render the conversation through the chat template; this assumes the
    # checkpoint's tokenizer ships a template that understands audio entries
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Stream tokens back to the UI as they are generated
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=512,
        temperature=0.7,
        do_sample=True,
    )

    # Run generate() in a background thread so this function can yield partial text
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
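# NOTE: the streaming path above assumes apply_chat_template can render audio
# content entries for this checkpoint. The Qwen-Audio-Chat model card instead
# documents a non-streaming interface along these lines (a sketch, not verified
# against the current remote code):
#
#     query = tokenizer.from_list_format([
#         {"audio": audio_input},
#         {"text": query},
#     ])
#     response, history = model.chat(tokenizer, query=query, history=None)
#
# If apply_chat_template raises for this tokenizer, fall back to model.chat.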
css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tab(label="Audio Input"):
        with gr.Row():
            with gr.Column():
                input_audio = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                )
                text_input = gr.Textbox(
                    label="Question (optional)",
                    placeholder="Ask a question about the audio or leave empty for a general description",
                )
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id wires this component to the #output rule in the CSS above
                output_text = gr.Textbox(label="Output Text", elem_id="output")

        submit_btn.click(
            qwen_inference,
            inputs=[input_audio, text_input],
            outputs=[output_text],
        )

demo.launch(debug=True)
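A minimal requirements.txt sketch for this Space, derived from the imports above (version pins omitted; accelerate is assumed to be needed for device_map="auto", and soundfile additionally requires the system libsndfile library):

gradio
spaces
transformers
torch
accelerate
soundfile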