Spaces:
Running
on
T4
Running
on
T4
File size: 7,246 Bytes
9c20b4e ab25593 9c20b4e ab25593 9c20b4e ab25593 b9d657b ab25593 9c20b4e 8ddd281 ab25593 9c20b4e ab25593 9c20b4e ab25593 8ddd281 ab25593 8ddd281 9c20b4e ab25593 9c20b4e ab25593 9c20b4e ab25593 9c20b4e ab25593 9c20b4e dc06293 ab25593 dc06293 8a1ab06 ab25593 8fa13bc ab25593 8ddd281 ab25593 2cf7afe ab25593 2cf7afe ab25593 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
"""
utils.py
Functions:
- generate_script: Get the dialogue from the LLM.
- call_llm: Call the LLM with the given prompt and dialogue format.
- parse_url: Parse the given URL and return the text content.
- generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models.
"""
# Standard library imports
import time
from typing import Any, Union
# Third-party imports
import requests
from bark import SAMPLE_RATE, generate_audio, preload_models
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError
from scipy.io.wavfile import write as write_wav
# Local imports
from constants import (
FIREWORKS_API_KEY,
FIREWORKS_BASE_URL,
FIREWORKS_MODEL_ID,
FIREWORKS_MAX_TOKENS,
FIREWORKS_TEMPERATURE,
FIREWORKS_JSON_RETRY_ATTEMPTS,
MELO_API_NAME,
MELO_TTS_SPACES_ID,
MELO_RETRY_ATTEMPTS,
MELO_RETRY_DELAY,
JINA_READER_URL,
JINA_RETRY_ATTEMPTS,
JINA_RETRY_DELAY,
)
from schema import ShortDialogue, MediumDialogue
# Initialize clients
# Fireworks is reached through its OpenAI-compatible API: an OpenAI client
# pointed at FIREWORKS_BASE_URL. Used by call_llm.
fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
# Gradio client for the MeloTTS Space; used by _use_melotts_api.
hf_client = Client(MELO_TTS_SPACES_ID)
# Download and load all models for Bark
# NOTE(review): this runs at import time, so importing this module loads the
# Bark models (and constructing the clients above may hit the network) —
# confirm that import-time cost is intended.
preload_models()
def generate_script(
    system_prompt: str,
    input_text: str,
    output_model: type[Union[ShortDialogue, MediumDialogue]],
) -> Union[ShortDialogue, MediumDialogue]:
    """Get the dialogue from the LLM.

    Runs two LLM stages, each validated against ``output_model``:

    1. Draft: generate a dialogue from ``input_text``.
    2. Refine: show the LLM its own draft and ask it to make the dialogue
       more natural and engaging.

    Each stage is attempted up to ``FIREWORKS_JSON_RETRY_ATTEMPTS`` times
    (at least once). When a response fails validation, the validation error
    is added to the system prompt and the LLM is called again.

    Args:
        system_prompt: System prompt describing the dialogue to produce.
        input_text: Source text the dialogue is based on.
        output_model: Pydantic model class the response JSON must validate
            against.

    Returns:
        The refined dialogue, an instance of ``output_model``.

    Raises:
        ValueError: If a stage still fails validation on its last attempt.
    """
    # A non-positive setting would skip the loops and leave the result
    # variables unbound, so always make at least one attempt.
    max_attempts = max(1, FIREWORKS_JSON_RETRY_ATTEMPTS)

    # Stage 1: first draft.
    response = call_llm(system_prompt, input_text, output_model)
    response_json = response.choices[0].message.content
    for attempt in range(max_attempts):
        try:
            first_draft_dialogue = output_model.model_validate_json(response_json)
            break
        except ValidationError as e:
            if attempt == max_attempts - 1:  # Last attempt
                raise ValueError(
                    f"Failed to parse dialogue JSON after {max_attempts} attempts: {e}"
                ) from e
            error_message = (
                f"Failed to parse dialogue JSON (attempt {attempt + 1}): {e}"
            )
            # Re-call the LLM with the error message. The new response is
            # validated at the top of the next iteration, inside the try, so a
            # repeated failure is retried instead of escaping the loop.
            system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
            response = call_llm(system_prompt_with_error, input_text, output_model)
            response_json = response.choices[0].message.content

    # Stage 2: ask the LLM to improve its own draft. The draft is embedded as
    # JSON, the same format the LLM is asked to produce, not as a Python repr.
    system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue.model_dump_json()}."
    for attempt in range(max_attempts):
        try:
            response = call_llm(
                system_prompt_with_dialogue,
                "Please improve the dialogue. Make it more natural and engaging.",
                output_model,
            )
            final_dialogue = output_model.model_validate_json(
                response.choices[0].message.content
            )
            break
        except ValidationError as e:
            if attempt == max_attempts - 1:  # Last attempt
                raise ValueError(
                    f"Failed to improve dialogue after {max_attempts} attempts: {e}"
                ) from e
            error_message = f"Failed to improve dialogue (attempt {attempt + 1}): {e}"
            # Errors accumulate across retries, so the LLM sees every earlier failure.
            system_prompt_with_dialogue += f"\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
    return final_dialogue
def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
    """Send a single chat-completion request to the configured Fireworks model.

    Args:
        system_prompt: Content of the system message.
        text: Content of the user message.
        dialogue_format: Pydantic model class. Its JSON schema is sent as the
            ``response_format`` to request JSON output that matches it.

    Returns:
        The raw completion object returned by ``fw_client``.
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    json_response_format = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return fw_client.chat.completions.create(
        messages=chat_messages,
        model=FIREWORKS_MODEL_ID,
        max_tokens=FIREWORKS_MAX_TOKENS,
        temperature=FIREWORKS_TEMPERATURE,
        response_format=json_response_format,
    )
def parse_url(url: str) -> str:
    """Parse the given URL and return the text content.

    The request goes to ``JINA_READER_URL`` with ``url`` appended, and the
    response body is returned as text. Failed requests are retried up to
    ``JINA_RETRY_ATTEMPTS`` times (at least once), sleeping
    ``JINA_RETRY_DELAY`` seconds between attempts.

    Args:
        url: The page URL to fetch.

    Returns:
        The response body text.

    Raises:
        ValueError: If every attempt fails with a ``requests`` error
            (connection problem, timeout, or 4xx/5xx status).
    """
    # Loop-invariant: build the request URL once.
    full_url = f"{JINA_READER_URL}{url}"
    # Always make at least one attempt; otherwise no response would exist.
    max_attempts = max(1, JINA_RETRY_ATTEMPTS)
    for attempt in range(max_attempts):
        try:
            response = requests.get(full_url, timeout=60)
            response.raise_for_status()  # Raise an exception for bad status codes
        except requests.RequestException as e:
            if attempt == max_attempts - 1:  # Last attempt
                raise ValueError(
                    f"Failed to fetch URL after {max_attempts} attempts: {e}"
                ) from e
            time.sleep(JINA_RETRY_DELAY)  # Back off before retrying
        else:
            return response.text
def generate_podcast_audio(
    text: str, speaker: str, language: str, use_advanced_audio: bool, random_voice_number: int
) -> str:
    """Synthesize one line of podcast dialogue and return the audio file path.

    Args:
        text: The line to speak.
        speaker: Who is speaking the line.
        language: Language code passed through to the selected backend.
        use_advanced_audio: If true, synthesize with Bark; otherwise call the
            MeloTTS Space.
        random_voice_number: Bark voice preset index. Unused by the MeloTTS
            path.

    Returns:
        Path to the generated audio, as returned by the selected backend.
    """
    if not use_advanced_audio:
        return _use_melotts_api(text, speaker, language)
    return _use_suno_model(text, speaker, language, random_voice_number)
def _use_suno_model(text: str, speaker: str, language: str, random_voice_number: int) -> str:
    """Generate advanced audio using Bark.

    The host (``speaker == 'Host (Jane)'``) uses Bark voice preset
    ``random_voice_number``. Every other speaker uses the next preset, so the
    two voices differ.

    Args:
        text: The line to speak.
        speaker: Speaker name; only ``'Host (Jane)'`` selects the host voice.
        language: Language code used in the preset name
            ``v2/{language}_speaker_{n}``.
        random_voice_number: Preset index for the host voice.

    Returns:
        Path of the written WAV file.
    """
    host_voice_num = str(random_voice_number)
    # NOTE(review): assumes callers keep random_voice_number + 1 within the
    # range of available Bark presets — confirm.
    guest_voice_num = str(random_voice_number + 1)
    voice_num = host_voice_num if speaker == "Host (Jane)" else guest_voice_num
    audio_array = generate_audio(
        text,
        history_prompt=f"v2/{language}_speaker_{voice_num}",
    )
    # write_wav always writes WAV data, so use a matching .wav extension
    # rather than .mp3.
    # NOTE(review): the path is fixed per (language, speaker), so each call
    # overwrites the previous file; callers must consume it first.
    file_path = f"audio_{language}_{speaker}.wav"
    write_wav(file_path, SAMPLE_RATE, audio_array)
    return file_path
def _use_melotts_api(text: str, speaker: str, language: str) -> str:
    """Generate audio using the MeloTTS Gradio Space.

    Calls ``hf_client.predict`` on the ``MELO_API_NAME`` endpoint, retrying up
    to ``MELO_RETRY_ATTEMPTS`` times (at least once) with ``MELO_RETRY_DELAY``
    seconds between attempts.

    Args:
        text: The line to speak.
        speaker: Speaker name, mapped to a MeloTTS accent and speed.
        language: MeloTTS language code.

    Returns:
        Whatever the Space endpoint returns; presumably a path to the
        generated audio file (TODO confirm).

    Raises:
        Exception: The last error from ``hf_client.predict`` if every attempt
            fails.
    """
    accent, speed = _get_melo_tts_params(speaker, language)
    # Always make at least one attempt, so the function never falls through
    # and implicitly returns None.
    max_attempts = max(1, MELO_RETRY_ATTEMPTS)
    for attempt in range(max_attempts):
        try:
            return hf_client.predict(
                text=text,
                language=language,
                speaker=accent,
                speed=speed,
                api_name=MELO_API_NAME,
            )
        except Exception:  # Broad catch: any failure of the remote call is retried.
            if attempt == max_attempts - 1:  # Last attempt
                raise  # Re-raise the last exception if all attempts fail
            time.sleep(MELO_RETRY_DELAY)  # Back off before retrying
def _get_melo_tts_params(speaker: str, language: str) -> tuple[str, float]:
"""Get TTS parameters based on speaker and language."""
if speaker == "Guest":
accent = "EN-US" if language == "EN" else language
speed = 0.9
else: # host
accent = "EN-Default" if language == "EN" else language
speed = (
1.1 if language != "EN" else 1
) # if the language is not English, try speeding up so it'll sound different from the host
# for non-English, there is only one voice
return accent, speed
|