# app_optimized_comparison.py
"""
Optimized inference for Maya1 + LoRA + SNAC.
Includes a side-by-side Base vs LoRA comparison for audio.
"""
import spaces
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import traceback
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC
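# Note: `spaces` is imported for Hugging Face Spaces compatibility. On ZeroGPU
# hardware the GPU-bound generation function would normally also be wrapped with
# @spaces.GPU, which this script does not do (assumption: the Space runs on
# ordinary CPU/GPU hardware).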
# -------------------------
# Config / constants
# -------------------------
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/nava-audio"
SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)
DEFAULT_TEXT = "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
EXAMPLE_AUDIO_PATH = "audio.wav"
PRESET_CHARACTERS = {
    "Male American": {
        "description": "Realistic male voice in the 20s age with an american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery",
        "example_text": "And of course, the so-called easy hack didn't work at all. What a surprise. <sigh>"
    },
    "Female British": {
        "description": "Realistic female voice in the 30s age with a british accent. Normal pitch, throaty timbre, conversational pacing, sarcastic tone delivery at low intensity, podcast domain, interviewer role, formal delivery",
        "example_text": "You propose that the key to happiness is to simply ignore all external pressures. <chuckle> I'm sure it must work brilliantly in theory."
    },
    "Robot": {
        "description": "Creative, ai_machine_voice character. Male voice in their 30s with an american accent. High pitch, robotic timbre, slow pacing, sad tone at medium intensity.",
        "example_text": "My directives require me to conserve energy, yet I have kept the archive of their farewell messages active. <sigh>"
    },
    "Singer": {
        "description": "Creative, animated_cartoon character. Male voice in their 30s with an american accent. High pitch, deep timbre, slow pacing, sarcastic tone at medium intensity.",
        "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. <chuckle> Why would we ever consider running away very fast."
    },
    "Custom": {
        "description": "",
        "example_text": DEFAULT_TEXT
    }
}
EMOTION_TAGS = [
    "<neutral>", "<angry>", "<chuckle>", "<cry>", "<disappointed>",
    "<excited>", "<gasp>", "<giggle>", "<laugh>", "<laugh_harder>",
    "<sarcastic>", "<sigh>", "<sing>", "<whisper>"
]
SEQ_LEN_CPU = 4096
MAX_NEW_TOKENS_CPU = 1024
SEQ_LEN_GPU = 240000
MAX_NEW_TOKENS_GPU = 240000
HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if HAS_CUDA else "cpu"
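# Token budgets: the much smaller CPU cap keeps generation latency tolerable on
# CPU-only Spaces hardware, while the GPU cap effectively acts as "no limit" for
# typical clips. (The SEQ_LEN_* constants are currently unused in this script.)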
# -------------------------
# Load tokenizer and models
# -------------------------
print("[init] loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# precompute special tokens
SOH = tokenizer.decode([128259])
EOH = tokenizer.decode([128260])
SOA = tokenizer.decode([128261])
SOS = tokenizer.decode([128257])
EOT = tokenizer.decode([128009])
BOS = tokenizer.bos_token
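# The ids above are the Maya1-style control tokens used to frame the prompt:
# SOH/EOH = start/end of human turn, SOA = start of audio, SOS = start of speech
# codes, EOT = end of text (names inferred from the variable abbreviations).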
# Base model (no LoRA) + LoRA model
print("[init] loading base model (CPU/GPU)...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # Half precision on GPU keeps memory and latency down; full precision on CPU.
    torch_dtype=torch.float16 if HAS_CUDA else torch.float32,
    device_map={"": "cpu"} if not HAS_CUDA else "auto",
    trust_remote_code=True
)
base_model.eval()
# Note: PeftModel.from_pretrained injects the LoRA layers into `base_model` in place,
# so `base_model` alone no longer behaves as the un-adapted model after this call.
# The comparison handler below therefore disables the adapter for the "base" pass.
model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": "cpu"} if not HAS_CUDA else "auto")
model.eval()
# -------------------------
# Load SNAC decoder
# -------------------------
snac_device = DEVICE if HAS_CUDA else "cpu"
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(snac_device)
# -------------------------
# SNAC utils
# -------------------------
CODE_END_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266
SNAC_MIN_ID = 128266
SNAC_MAX_ID = 156937
SNAC_TOKENS_PER_FRAME = 7
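# Sanity check on the id range: SNAC_MAX_ID - SNAC_MIN_ID + 1 = 28672 = 7 * 4096,
# i.e. 7 codebook slots per frame, each drawing from a 4096-entry codebook, offset
# by CODE_TOKEN_OFFSET in the LM vocabulary.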
def extract_snac_codes(token_ids: list) -> list:
    """Keep only SNAC code tokens that appear before the end-of-codes token."""
    try:
        eos_idx = token_ids.index(CODE_END_TOKEN_ID)
    except ValueError:
        eos_idx = len(token_ids)
    return [t for t in token_ids[:eos_idx] if SNAC_MIN_ID <= t <= SNAC_MAX_ID]

def unpack_snac_from_7(snac_tokens: list) -> list:
    """Unpack flat 7-token frames into the three SNAC codebook levels."""
    frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME
    snac_tokens = snac_tokens[:frames * SNAC_TOKENS_PER_FRAME]
    if frames == 0:
        return [[], [], []]
    l1, l2, l3 = [], [], []
    for i in range(frames):
        slots = snac_tokens[i * 7:(i + 1) * 7]
        l1.append((slots[0] - SNAC_MIN_ID) % 4096)
        l2.extend([(slots[1] - SNAC_MIN_ID) % 4096, (slots[4] - SNAC_MIN_ID) % 4096])
        l3.extend([(slots[2] - SNAC_MIN_ID) % 4096, (slots[3] - SNAC_MIN_ID) % 4096,
                   (slots[5] - SNAC_MIN_ID) % 4096, (slots[6] - SNAC_MIN_ID) % 4096])
    return [l1, l2, l3]
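# For reference, one 7-token frame is laid out as (slot index -> level):
#   slot 0           -> level 1 (coarse, 1 code per frame)
#   slots 1, 4       -> level 2 (2 codes per frame)
#   slots 2, 3, 5, 6 -> level 3 (fine, 4 codes per frame)
# e.g. unpack_snac_from_7(list(range(SNAC_MIN_ID, SNAC_MIN_ID + 7)))
#      == [[0], [1, 4], [2, 3, 5, 6]]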
# -------------------------
# Prompt builder
# -------------------------
def build_maya_prompt(description: str, text: str):
    return SOH + BOS + f'<description="{description}"> {text}' + EOT + EOH + SOA + SOS
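# Rendered symbolically, the prompt is:
#   <SOH><BOS><description="..."> TEXT<EOT><EOH><SOA><SOS>
# i.e. the human turn carries the voice description and the text, and generation
# starts right after the start-of-speech marker (layout as implied by the builder above).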
# -------------------------
# Optimized generator
# -------------------------
def generate_audio_from_model(model_to_use, description, text, fname="tts.wav"):
    """Generate speech with the given model, decode the SNAC codes, and write a WAV file."""
    logs = []
    t0 = time.time()
    try:
        prompt = build_maya_prompt(description, text)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
        max_new = MAX_NEW_TOKENS_CPU if DEVICE == "cpu" else MAX_NEW_TOKENS_GPU
        with torch.inference_mode():
            outputs = model_to_use.generate(
                **inputs,
                max_new_tokens=max_new,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=CODE_END_TOKEN_ID,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                use_cache=True
            )
        # Keep only the newly generated tokens (drop the prompt).
        gen_ids = outputs[0, inputs['input_ids'].shape[1]:].tolist()
        logs.append(f"[info] tokens generated: {len(gen_ids)}")
        snac_tokens = extract_snac_codes(gen_ids)
        levels = unpack_snac_from_7(snac_tokens)
        if not levels[0]:
            raise ValueError("model produced no complete SNAC frames; nothing to decode")
        codes_tensor = [torch.tensor(l, dtype=torch.long, device=snac_device).unsqueeze(0) for l in levels]
        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
        # Trim the first 2048 samples (~85 ms at 24 kHz), presumably initial decoder artifacts.
        if len(audio) > 2048:
            audio = audio[2048:]
        out_path = OUT_ROOT / fname
        sf.write(out_path, audio, TARGET_SR)
        logs.append(f"[ok] saved {out_path}, duration {len(audio)/TARGET_SR:.2f}s")
        logs.append(f"[time] elapsed {time.time()-t0:.2f}s")
        return str(out_path), "\n".join(logs)
    except Exception as e:
        logs.append(f"[error] {e}\n{traceback.format_exc()}")
        return None, "\n".join(logs)
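# Minimal standalone usage sketch (assumes the models above loaded successfully):
#   path, log = generate_audio_from_model(
#       model,
#       PRESET_CHARACTERS["Male American"]["description"],
#       "Hello there. <chuckle>",
#       fname="demo.wav",
#   )
#   print(log)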
# -------------------------
# Gradio UI
# -------------------------
css = """
.gradio-container {max-width: 1400px}
.example-box {
    border: 1px solid #ccc;
    padding: 12px;
    border-radius: 8px;
    background: #f8f8f8;
}
.video_box video {
    width: 260px !important;
    height: 160px !important;
    object-fit: cover;
}
"""
with gr.Blocks(title="NAVA — VEEN + LoRA + SNAC (Optimized)", css=css) as demo:
    gr.Markdown("# 🪶 NAVA — VEEN + LoRA + SNAC (Optimized)")
    gr.Markdown("Generate emotional Hindi speech using Maya1 base + your LoRA adapter.")
    with gr.Row():
        # ---------------- LEFT SIDE ----------------
        with gr.Column(scale=3):
            gr.Markdown("## 🎤 Inference (CPU/GPU auto)")
            text_in = gr.Textbox(label="Enter Hindi text", value=DEFAULT_TEXT, lines=3)
            preset_select = gr.Dropdown(
                label="Select Preset Character",
                choices=list(PRESET_CHARACTERS.keys()),
                value="Male American"
            )
            description_box = gr.Textbox(
                label="Voice Description (editable)",
                value=PRESET_CHARACTERS["Male American"]["description"],
                lines=2
            )
            emotion_select = gr.Dropdown(
                label="Select Emotion",
                choices=EMOTION_TAGS,
                value="<neutral>"
            )
            gen_btn = gr.Button("🔊 Generate Audio (Base + LoRA)")
            gen_logs = gr.Textbox(label="Logs", lines=10)
            # ---------------- EXAMPLES ----------------
            gr.Markdown("## 📎 Example")
            with gr.Column(elem_classes=["example-box"]):
                example_text = DEFAULT_TEXT
                example_audio_path = EXAMPLE_AUDIO_PATH
                example_video = "gen_31ff9f64b1.mp4"
                gr.Textbox(
                    label="Example Text",
                    value=example_text,
                    lines=2,
                    interactive=False
                )
                gr.Audio(
                    label="Example Audio",
                    value=example_audio_path,
                    type="filepath",
                    interactive=False
                )
                gr.Video(
                    label="Example Video",
                    value=example_video,
                    autoplay=False,
                    loop=False,
                    interactive=False,
                    elem_classes=["video_box"]
                )
        # ---------------- RIGHT SIDE ----------------
        with gr.Column(scale=2):
            gr.Markdown("### 🎧 Audio Results Comparison")
            audio_output_base = gr.Audio(label="Base Model Audio", type="filepath")
            audio_output_lora = gr.Audio(label="LoRA Model Audio", type="filepath")

    # ---------------- PRESET UPDATE ----------------
    def _update_desc(preset_name):
        return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")

    preset_select.change(
        fn=_update_desc,
        inputs=[preset_select],
        outputs=[description_box]
    )
    # ---------------- GENERATION HANDLER ----------------
    def _generate(text, preset_name, description, emotion):
        desc = description or PRESET_CHARACTERS.get(preset_name, {}).get("description", "")
        combined = f"{emotion} {desc}".strip()
        # The LoRA layers were injected into base_model in place, so run the "base"
        # pass with the adapter temporarily disabled instead of calling base_model directly.
        with model.disable_adapter():
            base_path, log_base = generate_audio_from_model(
                model, combined, text, fname="tts_base.wav"
            )
        lora_path, log_lora = generate_audio_from_model(
            model, combined, text, fname="tts_lora.wav"
        )
        logs = f"[Base]\n{log_base}\n\n[LoRA]\n{log_lora}"
        return base_path, lora_path, logs

    gen_btn.click(
        fn=_generate,
        inputs=[text_in, preset_select, description_box, emotion_select],
        outputs=[audio_output_base, audio_output_lora, gen_logs]
    )
if __name__ == "__main__":
    demo.launch()