Spaces:
Running
on
T4
Running
on
T4
File size: 4,247 Bytes
9c20b4e 8a1ab06 dc06293 8a1ab06 9c20b4e 8a1ab06 17dad6c 8a1ab06 9c20b4e 8a1ab06 9c20b4e 8a1ab06 9c20b4e d336953 dc06293 d336953 9c20b4e 8a1ab06 9c20b4e 9db5d78 9c20b4e 9db5d78 9c20b4e 9db5d78 9c20b4e d336953 9c20b4e dc06293 8a1ab06 dc06293 8a1ab06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
"""
utils.py
Functions:
- get_script: Get the dialogue from the LLM.
- call_llm: Call the LLM with the given prompt and dialogue format.
- get_audio: Get the audio from the TTS model from HF Spaces.
"""
import os
import requests
import tempfile
import soundfile as sf
import spaces
import torch
from gradio_client import Client
from openai import OpenAI
from parler_tts import ParlerTTSForConditionalGeneration
from pydantic import ValidationError
from transformers import AutoTokenizer
# Fireworks-hosted Llama 3.1 405B Instruct, used for script generation.
MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
# Jina reader proxy: prefixing a URL returns an LLM-friendly text rendering.
JINA_URL = "https://r.jina.ai/"
# OpenAI-compatible client pointed at the Fireworks inference endpoint.
# NOTE(review): api_key is None if FIREWORKS_API_KEY is unset — calls will
# fail at request time, not here.
client = OpenAI(
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=os.getenv("FIREWORKS_API_KEY"),
)
# Remote MeloTTS Space, used for non-English synthesis in generate_audio.
hf_client = Client("mrfakename/MeloTTS")
# Initialize the model and tokenizer (do this outside the function for efficiency)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
def generate_script(system_prompt: str, input_text: str, output_model):
    """Get the dialogue from the LLM.

    Asks the LLM for a response matching *output_model*'s JSON schema and
    validates it. If validation fails, retries once with the validation
    error appended to the system prompt so the model can correct itself.
    """
    try:
        llm_response = call_llm(system_prompt, input_text, output_model)
        return output_model.model_validate_json(
            llm_response.choices[0].message.content
        )
    except ValidationError as err:
        # Single retry: surface the parse failure to the model verbatim.
        retry_prompt = (
            f"{system_prompt}\n\nPlease return a VALID JSON object. "
            f"This was the earlier error: Failed to parse dialogue JSON: {err}"
        )
        llm_response = call_llm(retry_prompt, input_text, output_model)
        return output_model.model_validate_json(
            llm_response.choices[0].message.content
        )
def call_llm(system_prompt: str, text: str, dialogue_format):
    """Call the LLM with the given prompt and dialogue format.

    Sends a system+user message pair and constrains the output to
    *dialogue_format*'s JSON schema via Fireworks' structured-output mode.
    Returns the raw chat-completion response.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    json_schema_format = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return client.chat.completions.create(
        messages=conversation,
        model=MODEL_ID,
        max_tokens=16_384,
        temperature=0.1,  # low temperature: we want deterministic, valid JSON
        response_format=json_schema_format,
    )
def parse_url(url: str) -> str:
    """Parse the given URL and return the text content.

    Fetches the page through the Jina reader proxy, which returns an
    LLM-friendly text rendering of the page.

    Args:
        url: The page URL; appended directly to the Jina reader endpoint.

    Returns:
        The response body as text.

    Raises:
        requests.HTTPError: If the proxy responds with a 4xx/5xx status.
        requests.Timeout: If no response arrives within 60 seconds.
    """
    full_url = f"{JINA_URL}{url}"
    response = requests.get(full_url, timeout=60)
    # Fail loudly on HTTP errors instead of silently returning an error page
    # that would then be fed to the LLM as if it were article text.
    response.raise_for_status()
    return response.text
def generate_audio(text: str, speaker: str, language: str, voice: str) -> str:
    """Generate audio using the local Parler TTS model or HuggingFace client.

    English ("EN") is synthesized locally with Parler TTS and written to a
    temporary .mp3 file; any other language is delegated to the remote
    MeloTTS Space. Returns a path to the generated audio file.
    """
    if language != "EN":
        # Non-English path: remote MeloTTS Space.
        accent = language
        # Guests speak slightly slower than the host.
        speed = 0.9 if speaker == "Guest" else 1.1
        # NOTE(review): the language code is passed as the `speaker` argument
        # here — looks intentional for MeloTTS's API, but worth confirming.
        return hf_client.predict(
            text=text, language=language, speaker=accent, speed=speed, api_name="/synthesize"
        )

    # English path: local Parler TTS. The description conditions the voice.
    if speaker == "Guest":
        description = f"{voice} has a slightly expressive and animated speech, speaking at a moderate speed with natural pitch variations. The voice is clear and close-up, as if recorded in a professional studio."
    else:  # host
        description = f"{voice} has a professional and engaging tone, speaking at a moderate to slightly faster pace. The voice is clear, warm, and sounds like a seasoned podcast host."

    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    text_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    waveform = model.generate(input_ids=description_ids, prompt_input_ids=text_ids)
    audio_arr = waveform.cpu().numpy().squeeze()

    # Persist to a temp file the caller (e.g. Gradio) can serve; delete=False
    # keeps it alive after the handle closes.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
        sf.write(temp_file.name, audio_arr, model.config.sampling_rate, format='mp3')
    return temp_file.name
|