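# meralion-api / app.py
# Last commit 61a97ac by madharjan: "Remove INSTALL.md and test_requirements.py;
# clean up requirements.txt by removing torch"
# NOTE(review): this file still imports torch; confirm the runtime environment provides it.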
import importlib.util
import os
import traceback

import gradio as gr
import librosa
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Process-wide cache: both stay None until load_model() populates them.
model = None
processor = None

# Compute device: prefer a CUDA GPU whenever torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def load_model():
    """Load MERaLiON-2-10B and its processor once, caching them in module globals.

    The first call downloads/loads the weights; later calls return the cached
    pair immediately.

    Returns:
        tuple: ``(model, processor)``.
    """
    global model, processor
    if model is None:
        repo_id = "MERaLiON/MERaLiON-2-10B"
        print("Loading MERaLiON-2-10B model...")
        loaded_processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
        model_kwargs = {
            "use_safetensors": True,
            "trust_remote_code": True,
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
        }
        # FlashAttention-2 needs a CUDA GPU *and* the optional flash-attn package;
        # requesting it without both makes from_pretrained raise. Otherwise leave
        # attn_implementation unset so transformers picks its default.
        if device == "cuda" and importlib.util.find_spec("flash_attn") is not None:
            model_kwargs["attn_implementation"] = "flash_attention_2"
        loaded_model = AutoModelForSpeechSeq2Seq.from_pretrained(repo_id, **model_kwargs)
        # Publish both only after both loaded, so a failed load never leaves a
        # half-populated cache behind.
        processor, model = loaded_processor, loaded_model
        print("Model loaded successfully!")
    return model, processor
def _move_inputs_to_model(inputs, model):
    """Place processor outputs on the model's device and match its float dtype.

    Only ``torch.Tensor`` values are touched. Every tensor is moved to
    ``model.device``. float32 tensors (presumably the audio features) are also
    cast to ``model.dtype`` so they match the bfloat16 weights. Integer tensors
    such as ``input_ids`` keep their dtype.
    """
    # Snapshot the items so reassigning values never interferes with iteration.
    for key, value in list(inputs.items()):
        if not isinstance(value, torch.Tensor):
            continue
        if value.dtype == torch.float32:
            inputs[key] = value.to(device=model.device, dtype=model.dtype)
        else:
            inputs[key] = value.to(model.device)
    return inputs


def meralion_inference(prompt, uploaded_file):
    """Run MERaLiON-2-10B on an uploaded audio clip and return its text reply.

    Args:
        prompt: Free-text instruction for the model, e.g. "Please transcribe this speech.".
        uploaded_file: Value from the ``gr.File`` component. Depending on the
            Gradio version this is a filepath string or a tempfile-like object
            exposing ``.name``; both are accepted.

    Returns:
        str: The decoded model output, or a human-readable message. Errors are
        returned rather than raised so they appear in the output textbox.
    """
    if uploaded_file is None:
        return "Please upload an audio file."

    # Load model on first run; later calls reuse the cached pair.
    model, processor = load_model()

    try:
        # Gradio >= 4 passes a filepath str; older releases pass an object with .name.
        audio_path = getattr(uploaded_file, "name", uploaded_file)
        # Load audio resampled to 16 kHz.
        audio_array, _ = librosa.load(audio_path, sr=16000)

        prompt_template = (
            "Instruction: {query}\n"
            "Follow the text instruction based on the following audio: <SpeechHere>"
        )
        conversation = [
            {"role": "user", "content": prompt_template.format(query=prompt)}
        ]
        chat_prompt = processor.tokenizer.apply_chat_template(
            conversation=conversation, tokenize=False, add_generation_prompt=True
        )

        inputs = processor(text=chat_prompt, audios=audio_array)
        # The weights may sit on GPU in bfloat16; the inputs must match their device and dtype.
        inputs = _move_inputs_to_model(inputs, model)

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, do_sample=True, temperature=0.7
            )
        # generate() returns prompt + completion; keep only the newly generated tokens.
        generated_ids = outputs[:, inputs["input_ids"].size(1):]
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        # UI boundary: show the message to the user, but keep the traceback in the logs.
        traceback.print_exc()
        return f"Error during inference: {e}"
# Gradio UI: a prompt and an audio upload go in; the model's text reply comes out.
with gr.Blocks() as demo:
    gr.Markdown("# MERaLiON-2-10B Audio Demo")
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter Prompt", value="Please transcribe this speech.", lines=2
        )
        # The label lists every extension accepted by file_types below.
        file_input = gr.File(
            label="Upload Audio File (WAV/MP3/M4A, max 300s)",
            file_types=[".wav", ".mp3", ".m4a"],
        )
    output_text = gr.Textbox(label="Model Output", lines=8)
    submit_btn = gr.Button("Run Inference", variant="primary")
    submit_btn.click(
        meralion_inference, inputs=[prompt_input, file_input], outputs=output_text
    )

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module (tests, reload tooling) does not start a blocking server.
if __name__ == "__main__":
    demo.launch()