# app_fa3.py (rahul7star/Nava-Maya-INfrence)
import spaces
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import time
import traceback
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from snac import SNAC
from huggingface_hub import hf_hub_download
import os
# -------------------------
# Config
# -------------------------
LOCAL_MODEL = "rahul7star/nava1.1-maya"
LORA_NAME = "rahul7star/nava-audio"
SNAC_MODEL_NAME = "rahul7star/nava-snac"
COMPILED_HUB = "rahul7star/maya-compiled"
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
HAS_CUDA = DEVICE == "cuda"
DEFAULT_TEXT = "welcome to matrix .<sigh> . my name is bond ... james bond <laugh>"
EXAMPLE_AUDIO_PATH = "audio1.wav"
EXAMPLE_PROMPT = DEFAULT_TEXT
PRESET_CHARACTERS = {
"Male American": {"description": "Realistic male voice in the 20s age with an american accent.", "example_text": "And of course, the so-called easy hack didn't work at all. <sigh>"},
"Female British": {"description": "Realistic female voice in the 30s age with a british accent.", "example_text": "You propose that the key to happiness is to simply ignore all external pressures. <chuckle>"},
"Robot": {"description": "Creative, ai_machine_voice character. Male voice in their 30s.", "example_text": "My directives require me to conserve energy, yet I have kept the archive active. <sigh>"},
"Singer": {"description": "Creative, animated_cartoon character. Male voice in their 30s.", "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is viable. <chuckle>"},
"Custom": {"description": "", "example_text": DEFAULT_TEXT}
}
EMOTION_TAGS = ["<neutral>", "<angry>", "<chuckle>", "<cry>", "<disappointed>",
"<excited>", "<gasp>", "<giggle>", "<laugh>", "<laugh_harder>",
"<sarcastic>", "<sigh>", "<sing>", "<whisper>"]
# --------------------------------------------------------------
# NEW Quantization config (4‑bit, nf4)
# --------------------------------------------------------------
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True, # 4‑bit quantization
# bnb_4bit_quant_type="nf4", # “normal” 4‑bit (fast & accurate)
# bnb_4bit_use_double_quant=True, # optional: double‑quant for extra speed
# bnb_4bit_compute_dtype=torch.bfloat16 if HAS_CUDA else torch.float32,
# )
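# To try 4-bit loading, uncomment bnb_config above and pass
# quantization_config=bnb_config to AutoModelForCausalLM.from_pretrained below.
# This needs the bitsandbytes package installed; quantization is not enabled in
# this file, so treat it as an untested option.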
# -------------------------
# Model loader
# -------------------------
def load_model():
"""Try to load compiled HF model, else fall back to local + LoRA"""
global HAS_CUDA, DEVICE
try:
print("[init] trying to load compiled model from HF Hub...")
# Attempt to download compiled model files
try:
# This will raise if files don't exist
_ = hf_hub_download(repo_id=COMPILED_HUB, filename="pytorch_model.bin")
print("[init] found compiled model, loading...")
model_pt = AutoModelForCausalLM.from_pretrained(COMPILED_HUB, trust_remote_code=True, device_map="auto" if HAS_CUDA else {"": "cpu"})
tokenizer = AutoTokenizer.from_pretrained(COMPILED_HUB, trust_remote_code=True)
except Exception:
print("[init] no compiled model found, loading local + LoRA fallback...")
tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
LOCAL_MODEL,
torch_dtype=torch.bfloat16 if HAS_CUDA else torch.float32,
device_map="auto" if HAS_CUDA else {"": "cpu"},
trust_remote_code=True,
attn_implementation="kernels-community/vllm-flash-attn3",
#quantization_config=bnb_config,
)
model_pt = PeftModel.from_pretrained(base_model, LORA_NAME, device_map="auto" if HAS_CUDA else {"": "cpu"})
model_pt.eval()
# Pre-compile forward for speed, preserve .generate
if HAS_CUDA:
model_pt.forward = torch.compile(model_pt.forward)
return model_pt, tokenizer
except Exception as e:
raise RuntimeError(f"Failed to load model: {e}") from e
model_pt, tokenizer = load_model()
print("[init] loading SNAC decoder...")
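# SNAC (a multi-scale neural audio codec) decodes three hierarchical codebook
# streams back into a waveform; the checkpoint used here is assumed to be a
# 24 kHz variant, matching TARGET_SR above.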
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(DEVICE)
# -------------------------
# Prompt builder
# -------------------------
def build_maya_prompt(description, text):
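# The IDs below decode to the model's special-token strings: start/end of the
# human turn (soh/eoh), start of the AI turn (soa), start of speech (sos), and
# end of text (eot), following a Maya/Orpheus-style prompt layout (token names
# inferred from the variable names, not verified against the tokenizer config).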
soh, eoh, soa, sos, eot = [tokenizer.decode([i]) for i in [128259, 128260, 128261, 128257, 128009]]
return f"{soh}{tokenizer.bos_token}<description=\"{description}\"> {text}{eot}{eoh}{soa}{sos}"
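# Illustrative rendering (assumed; the actual special-token strings come from
# the tokenizer): build_maya_prompt("Realistic male voice ...", "Hello") yields
# roughly  <soh><bos><description="Realistic male voice ..."> Hello<eot><eoh><soa><sos>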
# -------------------------
# PyTorch backend generation
# -------------------------
@spaces.GPU()
def generate_pt(prompt_text):
t0 = time.time()
try:
inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)
with torch.inference_mode():
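# Token ID 128258, passed as eos_token_id below, is treated as the end-of-speech
# marker that stops generation (assumption: Maya uses an Orpheus-style
# audio-token layout; the ID is kept exactly as in the original code).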
outputs = model_pt.generate(
**inputs,
max_new_tokens=240000 if HAS_CUDA else 2048,
temperature=0.4,
top_p=0.9,
repetition_penalty=1.1,
do_sample=True,
eos_token_id=128258,
pad_token_id=tokenizer.pad_token_id,
use_cache=True
)
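# Keep only the newly generated tokens; everything up to the prompt length is
# the input prompt echoed back by generate().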
gen_ids = outputs[0, inputs['input_ids'].shape[1]:]
# SNAC decoding
SNAC_MIN, SNAC_MAX = 128266, 156937
snac_mask = (gen_ids >= SNAC_MIN) & (gen_ids <= SNAC_MAX)
snac_tokens = gen_ids[snac_mask]
frames = snac_tokens.shape[0] // 7
if frames==0: return None, None, "[warn] no SNAC frames"
snac_tokens = snac_tokens[:frames*7].reshape(frames,7)
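# Frame layout (7 tokens per frame): column 0 feeds SNAC codebook 1 (coarse),
# columns 1 and 4 feed codebook 2, and columns 2, 3, 5, 6 feed codebook 3 (fine).
# Subtracting SNAC_MIN and taking mod 4096 maps hub token IDs back to the
# 4096-entry codebook indices SNAC expects.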
l1 = (snac_tokens[:,0]-SNAC_MIN)%4096
l2 = torch.stack([(snac_tokens[:,1]-SNAC_MIN)%4096, (snac_tokens[:,4]-SNAC_MIN)%4096],1).flatten()
l3 = torch.stack([(snac_tokens[:,2]-SNAC_MIN)%4096, (snac_tokens[:,3]-SNAC_MIN)%4096,
(snac_tokens[:,5]-SNAC_MIN)%4096, (snac_tokens[:,6]-SNAC_MIN)%4096],1).flatten()
codes_tensor = [l1.unsqueeze(0).to(DEVICE), l2.unsqueeze(0).to(DEVICE), l3.unsqueeze(0).to(DEVICE)]
with torch.inference_mode():
z_q = snac_model.quantizer.from_codes(codes_tensor)
audio = snac_model.decoder(z_q)[0,0].cpu().numpy()
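# Drop the first 2048 samples (about 85 ms at 24 kHz), presumably to trim
# decoder warm-up artifacts at the start of the clip.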
audio = audio[2048:] if len(audio)>2048 else audio
out_path = OUT_ROOT / "tts_pt.wav"
sf.write(out_path, audio, TARGET_SR)
return str(out_path), str(out_path), f"[ok] PyTorch | elapsed {time.time()-t0:.2f}s"
except Exception as e:
return None, None, f"[error]{e}\n{traceback.format_exc()}"
# -------------------------
# Gradio App
# -------------------------
css = ".gradio-container {max-width: 1400px}"
with gr.Blocks(title="Text to Speech", css=css) as demo:
gr.Markdown("# 🪶 Text to Speech: MAYA + LoRA + SNAC (optimized) + FA3 + 4-bit quant")
gr.Markdown("# 🪶 GPU time per generation: roughly 0.3–0.5 seconds (WIP)")
with gr.Row():
with gr.Column(scale=3):
text_in = gr.Textbox(label="Enter text", value=DEFAULT_TEXT, lines=3)
preset_select = gr.Dropdown(label="Preset Character", choices=list(PRESET_CHARACTERS.keys()), value="Male American")
description_box = gr.Textbox(label="Voice Description (editable)", value=PRESET_CHARACTERS["Male American"]["description"], lines=2)
emotion_select = gr.Dropdown(label="Emotion", choices=EMOTION_TAGS, value="<neutral>")
gen_btn = gr.Button("🔊 Generate Audio")
gen_logs = gr.Textbox(label="Logs", lines=10)
with gr.Column(scale=2):
audio_player = gr.Audio(label="Generated Audio", type="filepath")
download_file = gr.File(label="Download generated file")
gr.Markdown("### Example")
gr.Textbox(label="Example Text", value=EXAMPLE_PROMPT, lines=2, interactive=False)
gr.Audio(label="Example Audio", value=EXAMPLE_AUDIO_PATH, type="filepath", interactive=False)
def _update_desc(preset_name):
return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")
preset_select.change(fn=_update_desc, inputs=[preset_select], outputs=[description_box])
def _generate(text_in, preset_select, description_box, emotion_select):
combined_desc = f"{emotion_select} {description_box}".strip()
final_prompt = build_maya_prompt(combined_desc, text_in)
return generate_pt(final_prompt)
gen_btn.click(fn=_generate,
inputs=[text_in, preset_select, description_box, emotion_select],
outputs=[audio_player, download_file, gen_logs])
if __name__=="__main__":
demo.launch()