Spaces:
Running
on
T4
Running
on
T4
gabriel chua
commited on
Chore: Clean up code (#2)
Browse files- app.py +87 -151
- constants.py +166 -0
- prompts.py +13 -0
- schema.py +34 -0
- utils.py +148 -77
app.py
CHANGED
@@ -8,69 +8,48 @@ import os
|
|
8 |
import time
|
9 |
from pathlib import Path
|
10 |
from tempfile import NamedTemporaryFile
|
11 |
-
from typing import List,
|
12 |
|
13 |
# Third-party imports
|
14 |
import gradio as gr
|
|
|
15 |
from loguru import logger
|
16 |
-
from pydantic import BaseModel, Field
|
17 |
from pypdf import PdfReader
|
18 |
from pydub import AudioSegment
|
19 |
|
20 |
# Local imports
|
21 |
-
from
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
"Hindi": "hi",
|
54 |
-
"Italian": "it",
|
55 |
-
"Japanese": "ja",
|
56 |
-
"Korean": "ko",
|
57 |
-
"Polish": "pl",
|
58 |
-
"Portuguese": "pt",
|
59 |
-
"Russian": "ru",
|
60 |
-
"Spanish": "es",
|
61 |
-
"Turkish": "tr"
|
62 |
-
}
|
63 |
-
|
64 |
-
MELO_TTS_LANGUAGE_MAPPING = {
|
65 |
-
"en": "EN",
|
66 |
-
"es": "ES",
|
67 |
-
"fr": "FR",
|
68 |
-
"zh": "ZJ",
|
69 |
-
"ja": "JP",
|
70 |
-
"ko": "KR",
|
71 |
-
}
|
72 |
-
|
73 |
-
|
74 |
|
75 |
|
76 |
def generate_podcast(
|
@@ -84,32 +63,30 @@ def generate_podcast(
|
|
84 |
) -> Tuple[str, str]:
|
85 |
"""Generate the audio and transcript from the PDFs and/or URL."""
|
86 |
|
87 |
-
|
88 |
-
|
89 |
text = ""
|
90 |
|
91 |
-
#
|
92 |
-
|
93 |
-
|
|
|
|
|
94 |
|
95 |
# Check if at least one input is provided
|
96 |
if not files and not url:
|
97 |
-
raise gr.Error(
|
98 |
|
99 |
# Process PDFs if any
|
100 |
if files:
|
101 |
for file in files:
|
102 |
if not file.lower().endswith(".pdf"):
|
103 |
-
raise gr.Error(
|
104 |
-
f"File {file} is not a PDF. Please upload only PDF files."
|
105 |
-
)
|
106 |
|
107 |
try:
|
108 |
with Path(file).open("rb") as f:
|
109 |
reader = PdfReader(f)
|
110 |
text += "\n\n".join([page.extract_text() for page in reader.pages])
|
111 |
except Exception as e:
|
112 |
-
raise gr.Error(f"
|
113 |
|
114 |
# Process URL if provided
|
115 |
if url:
|
@@ -120,34 +97,27 @@ def generate_podcast(
|
|
120 |
raise gr.Error(str(e))
|
121 |
|
122 |
# Check total character count
|
123 |
-
if len(text) >
|
124 |
-
raise gr.Error(
|
125 |
-
"The total content is too long. Please ensure the combined text from PDFs and URL is fewer than ~100,000 characters."
|
126 |
-
)
|
127 |
-
|
128 |
|
129 |
# Modify the system prompt based on the user input
|
130 |
modified_system_prompt = SYSTEM_PROMPT
|
|
|
131 |
if question:
|
132 |
-
modified_system_prompt += f"\n\
|
133 |
if tone:
|
134 |
-
modified_system_prompt += f"\n\
|
135 |
if length:
|
136 |
-
|
137 |
-
"Short (1-2 min)": "Keep the podcast brief, around 1-2 minutes long.",
|
138 |
-
"Medium (3-5 min)": "Aim for a moderate length, about 3-5 minutes.",
|
139 |
-
}
|
140 |
-
modified_system_prompt += f"\n\nLENGTH: {length_instructions[length]}"
|
141 |
if language:
|
142 |
-
modified_system_prompt +=
|
143 |
-
f"\n\nOUTPUT LANGUAGE <IMPORTANT>: The the podcast should be {language}."
|
144 |
-
)
|
145 |
|
146 |
# Call the LLM
|
147 |
if length == "Short (1-2 min)":
|
148 |
llm_output = generate_script(modified_system_prompt, text, ShortDialogue)
|
149 |
else:
|
150 |
llm_output = generate_script(modified_system_prompt, text, MediumDialogue)
|
|
|
151 |
logger.info(f"Generated dialogue: {llm_output}")
|
152 |
|
153 |
# Process the dialogue
|
@@ -164,14 +134,14 @@ def generate_podcast(
|
|
164 |
transcript += speaker + "\n\n"
|
165 |
total_characters += len(line.text)
|
166 |
|
167 |
-
language_for_tts =
|
168 |
|
169 |
if not use_advanced_audio:
|
170 |
language_for_tts = MELO_TTS_LANGUAGE_MAPPING[language_for_tts]
|
171 |
|
172 |
# Get audio file path
|
173 |
audio_file_path = generate_podcast_audio(
|
174 |
-
line.text, line.speaker, language_for_tts, use_advanced_audio
|
175 |
)
|
176 |
# Read the audio file into an AudioSegment
|
177 |
audio_segment = AudioSegment.from_file(audio_file_path)
|
@@ -181,7 +151,7 @@ def generate_podcast(
|
|
181 |
combined_audio = sum(audio_segments)
|
182 |
|
183 |
# Export the combined audio to a temporary file
|
184 |
-
temporary_directory =
|
185 |
os.makedirs(temporary_directory, exist_ok=True)
|
186 |
|
187 |
temporary_file = NamedTemporaryFile(
|
@@ -193,7 +163,10 @@ def generate_podcast(
|
|
193 |
|
194 |
# Delete any files in the temp directory that end with .mp3 and are over a day old
|
195 |
for file in glob.glob(f"{temporary_directory}*.mp3"):
|
196 |
-
if
|
|
|
|
|
|
|
197 |
os.remove(file)
|
198 |
|
199 |
logger.info(f"Generated {total_characters} characters of audio")
|
@@ -202,90 +175,53 @@ def generate_podcast(
|
|
202 |
|
203 |
|
204 |
demo = gr.Interface(
|
205 |
-
title=
|
206 |
-
description=
|
207 |
-
|
208 |
-
<table style="border-collapse: collapse; border: none; padding: 20px;">
|
209 |
-
<tr style="border: none;">
|
210 |
-
<td style="border: none; vertical-align: top; padding-right: 30px; padding-left: 30px;">
|
211 |
-
<img src="https://raw.githubusercontent.com/gabrielchua/daily-ai-papers/main/_includes/icon.png" alt="Open NotebookLM" width="120" style="margin-bottom: 10px;">
|
212 |
-
</td>
|
213 |
-
<td style="border: none; vertical-align: top; padding: 10px;">
|
214 |
-
<p style="margin-bottom: 15px;"><strong>Convert</strong> your PDFs into podcasts with open-source AI models (Llama 3.1 405B and MeloTTS).</p>
|
215 |
-
<p style="margin-top: 15px;">Note: Only the text content of the PDFs will be processed. Images and tables are not included. The total content should be no more than 100,000 characters due to the context length of Llama 3.1 405B.</p>
|
216 |
-
</td>
|
217 |
-
</tr>
|
218 |
-
</table>
|
219 |
-
""",
|
220 |
fn=generate_podcast,
|
221 |
inputs=[
|
222 |
gr.File(
|
223 |
-
label="
|
|
|
|
|
224 |
),
|
225 |
gr.Textbox(
|
226 |
-
label="
|
227 |
-
placeholder="
|
228 |
),
|
229 |
-
gr.Textbox(label="
|
230 |
gr.Dropdown(
|
231 |
-
|
232 |
-
|
233 |
-
value="
|
234 |
),
|
235 |
gr.Dropdown(
|
236 |
-
|
237 |
-
|
238 |
-
value="
|
239 |
),
|
240 |
gr.Dropdown(
|
241 |
-
choices=
|
242 |
-
value="
|
243 |
-
label="
|
244 |
),
|
245 |
gr.Checkbox(
|
246 |
-
label="
|
247 |
-
value=
|
248 |
-
)
|
249 |
],
|
250 |
outputs=[
|
251 |
-
gr.Audio(
|
252 |
-
|
|
|
|
|
253 |
],
|
254 |
-
allow_flagging=
|
255 |
-
api_name=
|
256 |
theme=gr.themes.Soft(),
|
257 |
-
concurrency_limit=
|
258 |
-
examples=
|
259 |
-
|
260 |
-
[str(Path("examples/1310.4546v1.pdf"))],
|
261 |
-
"",
|
262 |
-
"Explain this paper to me like I'm 5 years old",
|
263 |
-
"Fun",
|
264 |
-
"Short (1-2 min)",
|
265 |
-
"English",
|
266 |
-
True
|
267 |
-
],
|
268 |
-
[
|
269 |
-
[],
|
270 |
-
"https://en.wikipedia.org/wiki/Hugging_Face",
|
271 |
-
"How did Hugging Face become so successful?",
|
272 |
-
"Fun",
|
273 |
-
"Short (1-2 min)",
|
274 |
-
"English",
|
275 |
-
False
|
276 |
-
],
|
277 |
-
[
|
278 |
-
[],
|
279 |
-
"https://simple.wikipedia.org/wiki/Taylor_Swift",
|
280 |
-
"Why is Taylor Swift so popular?",
|
281 |
-
"Fun",
|
282 |
-
"Short (1-2 min)",
|
283 |
-
"English",
|
284 |
-
False
|
285 |
-
],
|
286 |
-
],
|
287 |
-
cache_examples=True,
|
288 |
)
|
289 |
|
290 |
if __name__ == "__main__":
|
291 |
-
demo.launch(show_api=
|
|
|
8 |
import time
|
9 |
from pathlib import Path
|
10 |
from tempfile import NamedTemporaryFile
|
11 |
+
from typing import List, Tuple, Optional
|
12 |
|
13 |
# Third-party imports
|
14 |
import gradio as gr
|
15 |
+
import random
|
16 |
from loguru import logger
|
|
|
17 |
from pypdf import PdfReader
|
18 |
from pydub import AudioSegment
|
19 |
|
20 |
# Local imports
|
21 |
+
from constants import (
|
22 |
+
APP_TITLE,
|
23 |
+
CHARACTER_LIMIT,
|
24 |
+
ERROR_MESSAGE_NOT_PDF,
|
25 |
+
ERROR_MESSAGE_NO_INPUT,
|
26 |
+
ERROR_MESSAGE_NOT_SUPPORTED_IN_MELO_TTS,
|
27 |
+
ERROR_MESSAGE_READING_PDF,
|
28 |
+
ERROR_MESSAGE_TOO_LONG,
|
29 |
+
GRADIO_CACHE_DIR,
|
30 |
+
GRADIO_CLEAR_CACHE_OLDER_THAN,
|
31 |
+
MELO_TTS_LANGUAGE_MAPPING,
|
32 |
+
NOT_SUPPORTED_IN_MELO_TTS,
|
33 |
+
SUNO_LANGUAGE_MAPPING,
|
34 |
+
UI_ALLOW_FLAGGING,
|
35 |
+
UI_API_NAME,
|
36 |
+
UI_CACHE_EXAMPLES,
|
37 |
+
UI_CONCURRENCY_LIMIT,
|
38 |
+
UI_DESCRIPTION,
|
39 |
+
UI_EXAMPLES,
|
40 |
+
UI_INPUTS,
|
41 |
+
UI_OUTPUTS,
|
42 |
+
UI_SHOW_API,
|
43 |
+
)
|
44 |
+
from prompts import (
|
45 |
+
LANGUAGE_MODIFIER,
|
46 |
+
LENGTH_MODIFIERS,
|
47 |
+
QUESTION_MODIFIER,
|
48 |
+
SYSTEM_PROMPT,
|
49 |
+
TONE_MODIFIER,
|
50 |
+
)
|
51 |
+
from schema import ShortDialogue, MediumDialogue
|
52 |
+
from utils import generate_podcast_audio, generate_script, parse_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
|
55 |
def generate_podcast(
|
|
|
63 |
) -> Tuple[str, str]:
|
64 |
"""Generate the audio and transcript from the PDFs and/or URL."""
|
65 |
|
|
|
|
|
66 |
text = ""
|
67 |
|
68 |
+
# Choose random number from 0 to 9
|
69 |
+
random_voice_number = random.randint(0, 9) # this is for suno model
|
70 |
+
|
71 |
+
if not use_advanced_audio and language in NOT_SUPPORTED_IN_MELO_TTS:
|
72 |
+
raise gr.Error(ERROR_MESSAGE_NOT_SUPPORTED_IN_MELO_TTS)
|
73 |
|
74 |
# Check if at least one input is provided
|
75 |
if not files and not url:
|
76 |
+
raise gr.Error(ERROR_MESSAGE_NO_INPUT)
|
77 |
|
78 |
# Process PDFs if any
|
79 |
if files:
|
80 |
for file in files:
|
81 |
if not file.lower().endswith(".pdf"):
|
82 |
+
raise gr.Error(ERROR_MESSAGE_NOT_PDF)
|
|
|
|
|
83 |
|
84 |
try:
|
85 |
with Path(file).open("rb") as f:
|
86 |
reader = PdfReader(f)
|
87 |
text += "\n\n".join([page.extract_text() for page in reader.pages])
|
88 |
except Exception as e:
|
89 |
+
raise gr.Error(f"{ERROR_MESSAGE_READING_PDF}: {str(e)}")
|
90 |
|
91 |
# Process URL if provided
|
92 |
if url:
|
|
|
97 |
raise gr.Error(str(e))
|
98 |
|
99 |
# Check total character count
|
100 |
+
if len(text) > CHARACTER_LIMIT:
|
101 |
+
raise gr.Error(ERROR_MESSAGE_TOO_LONG)
|
|
|
|
|
|
|
102 |
|
103 |
# Modify the system prompt based on the user input
|
104 |
modified_system_prompt = SYSTEM_PROMPT
|
105 |
+
|
106 |
if question:
|
107 |
+
modified_system_prompt += f"\n\n{QUESTION_MODIFIER} {question}"
|
108 |
if tone:
|
109 |
+
modified_system_prompt += f"\n\n{TONE_MODIFIER} {tone}."
|
110 |
if length:
|
111 |
+
modified_system_prompt += f"\n\n{LENGTH_MODIFIERS[length]}"
|
|
|
|
|
|
|
|
|
112 |
if language:
|
113 |
+
modified_system_prompt += f"\n\n{LANGUAGE_MODIFIER} {language}."
|
|
|
|
|
114 |
|
115 |
# Call the LLM
|
116 |
if length == "Short (1-2 min)":
|
117 |
llm_output = generate_script(modified_system_prompt, text, ShortDialogue)
|
118 |
else:
|
119 |
llm_output = generate_script(modified_system_prompt, text, MediumDialogue)
|
120 |
+
|
121 |
logger.info(f"Generated dialogue: {llm_output}")
|
122 |
|
123 |
# Process the dialogue
|
|
|
134 |
transcript += speaker + "\n\n"
|
135 |
total_characters += len(line.text)
|
136 |
|
137 |
+
language_for_tts = SUNO_LANGUAGE_MAPPING[language]
|
138 |
|
139 |
if not use_advanced_audio:
|
140 |
language_for_tts = MELO_TTS_LANGUAGE_MAPPING[language_for_tts]
|
141 |
|
142 |
# Get audio file path
|
143 |
audio_file_path = generate_podcast_audio(
|
144 |
+
line.text, line.speaker, language_for_tts, use_advanced_audio, random_voice_number
|
145 |
)
|
146 |
# Read the audio file into an AudioSegment
|
147 |
audio_segment = AudioSegment.from_file(audio_file_path)
|
|
|
151 |
combined_audio = sum(audio_segments)
|
152 |
|
153 |
# Export the combined audio to a temporary file
|
154 |
+
temporary_directory = GRADIO_CACHE_DIR
|
155 |
os.makedirs(temporary_directory, exist_ok=True)
|
156 |
|
157 |
temporary_file = NamedTemporaryFile(
|
|
|
163 |
|
164 |
# Delete any files in the temp directory that end with .mp3 and are over a day old
|
165 |
for file in glob.glob(f"{temporary_directory}*.mp3"):
|
166 |
+
if (
|
167 |
+
os.path.isfile(file)
|
168 |
+
and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
|
169 |
+
):
|
170 |
os.remove(file)
|
171 |
|
172 |
logger.info(f"Generated {total_characters} characters of audio")
|
|
|
175 |
|
176 |
|
177 |
demo = gr.Interface(
|
178 |
+
title=APP_TITLE,
|
179 |
+
description=UI_DESCRIPTION,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
fn=generate_podcast,
|
181 |
inputs=[
|
182 |
gr.File(
|
183 |
+
label=UI_INPUTS["file_upload"]["label"], # Step 1: File upload
|
184 |
+
file_types=UI_INPUTS["file_upload"]["file_types"],
|
185 |
+
file_count=UI_INPUTS["file_upload"]["file_count"],
|
186 |
),
|
187 |
gr.Textbox(
|
188 |
+
label=UI_INPUTS["url"]["label"], # Step 2: URL
|
189 |
+
placeholder=UI_INPUTS["url"]["placeholder"],
|
190 |
),
|
191 |
+
gr.Textbox(label=UI_INPUTS["question"]["label"]), # Step 3: Question
|
192 |
gr.Dropdown(
|
193 |
+
label=UI_INPUTS["tone"]["label"], # Step 4: Tone
|
194 |
+
choices=UI_INPUTS["tone"]["choices"],
|
195 |
+
value=UI_INPUTS["tone"]["value"],
|
196 |
),
|
197 |
gr.Dropdown(
|
198 |
+
label=UI_INPUTS["length"]["label"], # Step 5: Length
|
199 |
+
choices=UI_INPUTS["length"]["choices"],
|
200 |
+
value=UI_INPUTS["length"]["value"],
|
201 |
),
|
202 |
gr.Dropdown(
|
203 |
+
choices=UI_INPUTS["language"]["choices"], # Step 6: Language
|
204 |
+
value=UI_INPUTS["language"]["value"],
|
205 |
+
label=UI_INPUTS["language"]["label"],
|
206 |
),
|
207 |
gr.Checkbox(
|
208 |
+
label=UI_INPUTS["advanced_audio"]["label"],
|
209 |
+
value=UI_INPUTS["advanced_audio"]["value"],
|
210 |
+
),
|
211 |
],
|
212 |
outputs=[
|
213 |
+
gr.Audio(
|
214 |
+
label=UI_OUTPUTS["audio"]["label"], format=UI_OUTPUTS["audio"]["format"]
|
215 |
+
),
|
216 |
+
gr.Markdown(label=UI_OUTPUTS["transcript"]["label"]),
|
217 |
],
|
218 |
+
allow_flagging=UI_ALLOW_FLAGGING,
|
219 |
+
api_name=UI_API_NAME,
|
220 |
theme=gr.themes.Soft(),
|
221 |
+
concurrency_limit=UI_CONCURRENCY_LIMIT,
|
222 |
+
# examples=UI_EXAMPLES,
|
223 |
+
# cache_examples=UI_CACHE_EXAMPLES,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
224 |
)
|
225 |
|
226 |
if __name__ == "__main__":
|
227 |
+
demo.launch(show_api=UI_SHOW_API)
|
constants.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
constants.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
# Key constants
|
10 |
+
APP_TITLE = "Open NotebookLM"
|
11 |
+
CHARACTER_LIMIT = 100_000
|
12 |
+
|
13 |
+
# Gradio-related constants
|
14 |
+
GRADIO_CACHE_DIR = "./gradio_cached_examples/tmp/"
|
15 |
+
GRADIO_CLEAR_CACHE_OLDER_THAN = 1 * 24 * 60 * 60 # 1 day
|
16 |
+
|
17 |
+
# Error messages-related constants
|
18 |
+
ERROR_MESSAGE_NO_INPUT = "Please provide at least one PDF file or a URL."
|
19 |
+
ERROR_MESSAGE_NOT_PDF = "The provided file is not a PDF. Please upload only PDF files."
|
20 |
+
ERROR_MESSAGE_NOT_SUPPORTED_IN_MELO_TTS = "The selected language is not supported without advanced audio generation. Please enable advanced audio generation or choose a supported language."
|
21 |
+
ERROR_MESSAGE_READING_PDF = "Error reading the PDF file"
|
22 |
+
ERROR_MESSAGE_TOO_LONG = "The total content is too long. Please ensure the combined text from PDFs and URL is fewer than {CHARACTER_LIMIT} characters."
|
23 |
+
|
24 |
+
# Fireworks API-related constants
|
25 |
+
FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
|
26 |
+
FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
|
27 |
+
FIREWORKS_MAX_TOKENS = 16_384
|
28 |
+
FIREWORKS_MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
|
29 |
+
FIREWORKS_TEMPERATURE = 0.1
|
30 |
+
FIREWORKS_JSON_RETRY_ATTEMPTS = 3
|
31 |
+
|
32 |
+
# MeloTTS
|
33 |
+
MELO_API_NAME = "/synthesize"
|
34 |
+
MELO_TTS_SPACES_ID = "mrfakename/MeloTTS"
|
35 |
+
MELO_RETRY_ATTEMPTS = 3
|
36 |
+
MELO_RETRY_DELAY = 5 # in seconds
|
37 |
+
|
38 |
+
MELO_TTS_LANGUAGE_MAPPING = {
|
39 |
+
"en": "EN",
|
40 |
+
"es": "ES",
|
41 |
+
"fr": "FR",
|
42 |
+
"zh": "ZJ",
|
43 |
+
"ja": "JP",
|
44 |
+
"ko": "KR",
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
# Suno related constants
|
49 |
+
SUNO_LANGUAGE_MAPPING = {
|
50 |
+
"English": "en",
|
51 |
+
"Chinese": "zh",
|
52 |
+
"French": "fr",
|
53 |
+
"German": "de",
|
54 |
+
"Hindi": "hi",
|
55 |
+
"Italian": "it",
|
56 |
+
"Japanese": "ja",
|
57 |
+
"Korean": "ko",
|
58 |
+
"Polish": "pl",
|
59 |
+
"Portuguese": "pt",
|
60 |
+
"Russian": "ru",
|
61 |
+
"Spanish": "es",
|
62 |
+
"Turkish": "tr",
|
63 |
+
}
|
64 |
+
|
65 |
+
# General audio-related constants
|
66 |
+
NOT_SUPPORTED_IN_MELO_TTS = list(
|
67 |
+
set(SUNO_LANGUAGE_MAPPING.values()) - set(MELO_TTS_LANGUAGE_MAPPING.keys())
|
68 |
+
)
|
69 |
+
NOT_SUPPORTED_IN_MELO_TTS = [
|
70 |
+
key for key, id in SUNO_LANGUAGE_MAPPING.items() if id in NOT_SUPPORTED_IN_MELO_TTS
|
71 |
+
]
|
72 |
+
|
73 |
+
# Jina Reader-related constants
|
74 |
+
JINA_READER_URL = "https://r.jina.ai/"
|
75 |
+
JINA_RETRY_ATTEMPTS = 3
|
76 |
+
JINA_RETRY_DELAY = 5 # in seconds
|
77 |
+
|
78 |
+
# UI-related constants
|
79 |
+
UI_DESCRIPTION = """
|
80 |
+
<table style="border-collapse: collapse; border: none; padding: 20px;">
|
81 |
+
<tr style="border: none;">
|
82 |
+
<td style="border: none; vertical-align: top; padding-right: 30px; padding-left: 30px;">
|
83 |
+
<img src="https://raw.githubusercontent.com/gabrielchua/daily-ai-papers/main/_includes/icon.png" alt="Open NotebookLM" width="120" style="margin-bottom: 10px;">
|
84 |
+
</td>
|
85 |
+
<td style="border: none; vertical-align: top; padding: 10px;">
|
86 |
+
<p style="margin-bottom: 15px;">Convert your PDFs into podcasts with open-source AI models (<a href="https://huggingface.co/meta-llama/Llama-3.1-405B">Llama 3.1 405B</a>, <a href="https://huggingface.co/myshell-ai/MeloTTS-English">MeloTTS</a>, <a href="https://huggingface.co/suno/bark">Bark</a>).</p>
|
87 |
+
<p style="margin-top: 15px;">Note: Only the text content of the PDFs will be processed. Images and tables are not included. The total content should be no more than 100,000 characters due to the context length of Llama 3.1 405B.</p>
|
88 |
+
</td>
|
89 |
+
</tr>
|
90 |
+
</table>
|
91 |
+
"""
|
92 |
+
UI_AVAILABLE_LANGUAGES = list(set(SUNO_LANGUAGE_MAPPING.keys()))
|
93 |
+
UI_INPUTS = {
|
94 |
+
"file_upload": {
|
95 |
+
"label": "1. 📄 Upload your PDF(s)",
|
96 |
+
"file_types": [".pdf"],
|
97 |
+
"file_count": "multiple",
|
98 |
+
},
|
99 |
+
"url": {
|
100 |
+
"label": "2. 🔗 Paste a URL (optional)",
|
101 |
+
"placeholder": "Enter a URL to include its content",
|
102 |
+
},
|
103 |
+
"question": {
|
104 |
+
"label": "3. 🤔 Do you have a specific question or topic in mind?",
|
105 |
+
"placeholder": "Enter a question or topic",
|
106 |
+
},
|
107 |
+
"tone": {
|
108 |
+
"label": "4. 🎭 Choose the tone",
|
109 |
+
"choices": ["Fun", "Formal"],
|
110 |
+
"value": "Fun",
|
111 |
+
},
|
112 |
+
"length": {
|
113 |
+
"label": "5. ⏱️ Choose the length",
|
114 |
+
"choices": ["Short (1-2 min)", "Medium (3-5 min)"],
|
115 |
+
"value": "Medium (3-5 min)",
|
116 |
+
},
|
117 |
+
"language": {
|
118 |
+
"label": "6. 🌐 Choose the language",
|
119 |
+
"choices": UI_AVAILABLE_LANGUAGES,
|
120 |
+
"value": "English",
|
121 |
+
},
|
122 |
+
"advanced_audio": {
|
123 |
+
"label": "7. 🔄 Use advanced audio generation? (Experimental)",
|
124 |
+
"value": False,
|
125 |
+
},
|
126 |
+
}
|
127 |
+
UI_OUTPUTS = {
|
128 |
+
"audio": {"label": "🔊 Podcast", "format": "mp3"},
|
129 |
+
"transcript": {
|
130 |
+
"label": "📜 Transcript",
|
131 |
+
},
|
132 |
+
}
|
133 |
+
UI_API_NAME = "generate_podcast"
|
134 |
+
UI_ALLOW_FLAGGING = "never"
|
135 |
+
UI_CONCURRENCY_LIMIT = 3
|
136 |
+
UI_EXAMPLES = [
|
137 |
+
[
|
138 |
+
[str(Path("examples/1310.4546v1.pdf"))],
|
139 |
+
"",
|
140 |
+
"Explain this paper to me like I'm 5 years old",
|
141 |
+
"Fun",
|
142 |
+
"Short (1-2 min)",
|
143 |
+
"English",
|
144 |
+
True,
|
145 |
+
],
|
146 |
+
[
|
147 |
+
[],
|
148 |
+
"https://en.wikipedia.org/wiki/Hugging_Face",
|
149 |
+
"How did Hugging Face become so successful?",
|
150 |
+
"Fun",
|
151 |
+
"Short (1-2 min)",
|
152 |
+
"English",
|
153 |
+
False,
|
154 |
+
],
|
155 |
+
[
|
156 |
+
[],
|
157 |
+
"https://simple.wikipedia.org/wiki/Taylor_Swift",
|
158 |
+
"Why is Taylor Swift so popular?",
|
159 |
+
"Fun",
|
160 |
+
"Short (1-2 min)",
|
161 |
+
"English",
|
162 |
+
False,
|
163 |
+
],
|
164 |
+
]
|
165 |
+
UI_CACHE_EXAMPLES = True
|
166 |
+
UI_SHOW_API = True
|
prompts.py
CHANGED
@@ -51,5 +51,18 @@ You are a world-class podcast producer tasked with transforming the provided inp
|
|
51 |
- Include brief "breather" moments for listeners to absorb complex information
|
52 |
- End on a high note, perhaps with a thought-provoking question or a call-to-action for listeners
|
53 |
|
|
|
|
|
54 |
Remember: Always reply in valid JSON format, without code blocks. Begin directly with the JSON output.
|
55 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
- Include brief "breather" moments for listeners to absorb complex information
|
52 |
- End on a high note, perhaps with a thought-provoking question or a call-to-action for listeners
|
53 |
|
54 |
+
IMPORTANT RULE: Each line of dialogue should be no more than 100 characters (e.g., can finish within 5-8 seconds)
|
55 |
+
|
56 |
Remember: Always reply in valid JSON format, without code blocks. Begin directly with the JSON output.
|
57 |
"""
|
58 |
+
|
59 |
+
QUESTION_MODIFIER = "PLEASE ANSWER THE FOLLOWING QN:"
|
60 |
+
|
61 |
+
TONE_MODIFIER = "TONE: The tone of the podcast should be"
|
62 |
+
|
63 |
+
LANGUAGE_MODIFIER = "OUTPUT LANGUAGE <IMPORTANT>: The the podcast should be"
|
64 |
+
|
65 |
+
LENGTH_MODIFIERS = {
|
66 |
+
"Short (1-2 min)": "Keep the podcast brief, around 1-2 minutes long.",
|
67 |
+
"Medium (3-5 min)": "Aim for a moderate length, about 3-5 minutes.",
|
68 |
+
}
|
schema.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
schema.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Literal, List
|
6 |
+
|
7 |
+
from pydantic import BaseModel, Field
|
8 |
+
|
9 |
+
|
10 |
+
class DialogueItem(BaseModel):
|
11 |
+
"""A single dialogue item."""
|
12 |
+
|
13 |
+
speaker: Literal["Host (Jane)", "Guest"]
|
14 |
+
text: str
|
15 |
+
|
16 |
+
|
17 |
+
class ShortDialogue(BaseModel):
|
18 |
+
"""The dialogue between the host and guest."""
|
19 |
+
|
20 |
+
scratchpad: str
|
21 |
+
name_of_guest: str
|
22 |
+
dialogue: List[DialogueItem] = Field(
|
23 |
+
..., description="A list of dialogue items, typically between 11 to 17 items"
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class MediumDialogue(BaseModel):
|
28 |
+
"""The dialogue between the host and guest."""
|
29 |
+
|
30 |
+
scratchpad: str
|
31 |
+
name_of_guest: str
|
32 |
+
dialogue: List[DialogueItem] = Field(
|
33 |
+
..., description="A list of dialogue items, typically between 19 to 29 items"
|
34 |
+
)
|
utils.py
CHANGED
@@ -2,68 +2,115 @@
|
|
2 |
utils.py
|
3 |
|
4 |
Functions:
|
5 |
-
-
|
6 |
- call_llm: Call the LLM with the given prompt and dialogue format.
|
7 |
-
-
|
|
|
8 |
"""
|
9 |
|
10 |
-
|
11 |
-
import requests
|
12 |
import time
|
|
|
|
|
|
|
|
|
|
|
13 |
from gradio_client import Client
|
14 |
from openai import OpenAI
|
15 |
from pydantic import ValidationError
|
16 |
-
|
17 |
-
from bark import SAMPLE_RATE, generate_audio, preload_models
|
18 |
from scipy.io.wavfile import write as write_wav
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
)
|
|
|
27 |
|
28 |
-
|
|
|
|
|
29 |
|
30 |
-
#
|
31 |
preload_models()
|
32 |
|
33 |
|
34 |
-
def generate_script(
|
|
|
|
|
|
|
|
|
35 |
"""Get the dialogue from the LLM."""
|
36 |
-
# Load as python object
|
37 |
-
try:
|
38 |
-
response = call_llm(system_prompt, input_text, output_model)
|
39 |
-
dialogue = output_model.model_validate_json(response.choices[0].message.content)
|
40 |
-
except ValidationError as e:
|
41 |
-
error_message = f"Failed to parse dialogue JSON: {e}"
|
42 |
-
system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
|
43 |
-
response = call_llm(system_prompt_with_error, input_text, output_model)
|
44 |
-
dialogue = output_model.model_validate_json(response.choices[0].message.content)
|
45 |
-
|
46 |
-
# Call the LLM again to improve the dialogue
|
47 |
-
system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{dialogue}."
|
48 |
-
response = call_llm(
|
49 |
-
system_prompt_with_dialogue, "Please improve the dialogue.", output_model
|
50 |
-
)
|
51 |
-
improved_dialogue = output_model.model_validate_json(
|
52 |
-
response.choices[0].message.content
|
53 |
-
)
|
54 |
-
return improved_dialogue
|
55 |
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
"""Call the LLM with the given prompt and dialogue format."""
|
59 |
-
response =
|
60 |
messages=[
|
61 |
{"role": "system", "content": system_prompt},
|
62 |
{"role": "user", "content": text},
|
63 |
],
|
64 |
-
model=
|
65 |
-
max_tokens=
|
66 |
-
temperature=
|
67 |
response_format={
|
68 |
"type": "json_object",
|
69 |
"schema": dialogue_format.model_json_schema(),
|
@@ -74,46 +121,70 @@ def call_llm(system_prompt: str, text: str, dialogue_format):
|
|
74 |
|
75 |
def parse_url(url: str) -> str:
|
76 |
"""Parse the given URL and return the text content."""
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
return response.text
|
80 |
|
81 |
|
82 |
-
def generate_podcast_audio(
|
83 |
-
|
|
|
|
|
84 |
if use_advanced_audio:
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
# save audio to disk
|
90 |
-
write_wav(file_path, SAMPLE_RATE, audio_array)
|
91 |
-
|
92 |
-
return file_path
|
93 |
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
else
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
utils.py
|
3 |
|
4 |
Functions:
|
5 |
+
- generate_script: Get the dialogue from the LLM.
|
6 |
- call_llm: Call the LLM with the given prompt and dialogue format.
|
7 |
+
- parse_url: Parse the given URL and return the text content.
|
8 |
+
- generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models.
|
9 |
"""
|
10 |
|
11 |
+
# Standard library imports
|
|
|
12 |
import time
|
13 |
+
from typing import Any, Union
|
14 |
+
|
15 |
+
# Third-party imports
|
16 |
+
import requests
|
17 |
+
from bark import SAMPLE_RATE, generate_audio, preload_models
|
18 |
from gradio_client import Client
|
19 |
from openai import OpenAI
|
20 |
from pydantic import ValidationError
|
|
|
|
|
21 |
from scipy.io.wavfile import write as write_wav
|
22 |
|
23 |
+
# Local imports
|
24 |
+
from constants import (
|
25 |
+
FIREWORKS_API_KEY,
|
26 |
+
FIREWORKS_BASE_URL,
|
27 |
+
FIREWORKS_MODEL_ID,
|
28 |
+
FIREWORKS_MAX_TOKENS,
|
29 |
+
FIREWORKS_TEMPERATURE,
|
30 |
+
FIREWORKS_JSON_RETRY_ATTEMPTS,
|
31 |
+
MELO_API_NAME,
|
32 |
+
MELO_TTS_SPACES_ID,
|
33 |
+
MELO_RETRY_ATTEMPTS,
|
34 |
+
MELO_RETRY_DELAY,
|
35 |
+
JINA_READER_URL,
|
36 |
+
JINA_RETRY_ATTEMPTS,
|
37 |
+
JINA_RETRY_DELAY,
|
38 |
)
|
39 |
+
from schema import ShortDialogue, MediumDialogue
|
40 |
|
41 |
+
# Initialize clients
|
42 |
+
fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
|
43 |
+
hf_client = Client(MELO_TTS_SPACES_ID)
|
44 |
|
45 |
+
# Download and load all models for Bark
|
46 |
preload_models()
|
47 |
|
48 |
|
49 |
+
def generate_script(
|
50 |
+
system_prompt: str,
|
51 |
+
input_text: str,
|
52 |
+
output_model: Union[ShortDialogue, MediumDialogue],
|
53 |
+
) -> Union[ShortDialogue, MediumDialogue]:
|
54 |
"""Get the dialogue from the LLM."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
# Call the LLM
|
57 |
+
response = call_llm(system_prompt, input_text, output_model)
|
58 |
+
response_json = response.choices[0].message.content
|
59 |
+
|
60 |
+
# Validate the response
|
61 |
+
for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
|
62 |
+
try:
|
63 |
+
first_draft_dialogue = output_model.model_validate_json(response_json)
|
64 |
+
break
|
65 |
+
except ValidationError as e:
|
66 |
+
if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
|
67 |
+
raise ValueError(
|
68 |
+
f"Failed to parse dialogue JSON after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
|
69 |
+
) from e
|
70 |
+
error_message = (
|
71 |
+
f"Failed to parse dialogue JSON (attempt {attempt + 1}): {e}"
|
72 |
+
)
|
73 |
+
# Re-call the LLM with the error message
|
74 |
+
system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
|
75 |
+
response = call_llm(system_prompt_with_error, input_text, output_model)
|
76 |
+
response_json = response.choices[0].message.content
|
77 |
+
first_draft_dialogue = output_model.model_validate_json(response_json)
|
78 |
+
|
79 |
+
# Call the LLM a second time to improve the dialogue
|
80 |
+
system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue}."
|
81 |
+
|
82 |
+
# Validate the response
|
83 |
+
for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
|
84 |
+
try:
|
85 |
+
response = call_llm(
|
86 |
+
system_prompt_with_dialogue,
|
87 |
+
"Please improve the dialogue. Make it more natural and engaging.",
|
88 |
+
output_model,
|
89 |
+
)
|
90 |
+
final_dialogue = output_model.model_validate_json(
|
91 |
+
response.choices[0].message.content
|
92 |
+
)
|
93 |
+
break
|
94 |
+
except ValidationError as e:
|
95 |
+
if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
|
96 |
+
raise ValueError(
|
97 |
+
f"Failed to improve dialogue after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
|
98 |
+
) from e
|
99 |
+
error_message = f"Failed to improve dialogue (attempt {attempt + 1}): {e}"
|
100 |
+
system_prompt_with_dialogue += f"\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
|
101 |
+
return final_dialogue
|
102 |
+
|
103 |
+
|
104 |
+
def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
|
105 |
"""Call the LLM with the given prompt and dialogue format."""
|
106 |
+
response = fw_client.chat.completions.create(
|
107 |
messages=[
|
108 |
{"role": "system", "content": system_prompt},
|
109 |
{"role": "user", "content": text},
|
110 |
],
|
111 |
+
model=FIREWORKS_MODEL_ID,
|
112 |
+
max_tokens=FIREWORKS_MAX_TOKENS,
|
113 |
+
temperature=FIREWORKS_TEMPERATURE,
|
114 |
response_format={
|
115 |
"type": "json_object",
|
116 |
"schema": dialogue_format.model_json_schema(),
|
|
|
121 |
|
122 |
def parse_url(url: str) -> str:
|
123 |
"""Parse the given URL and return the text content."""
|
124 |
+
for attempt in range(JINA_RETRY_ATTEMPTS):
|
125 |
+
try:
|
126 |
+
full_url = f"{JINA_READER_URL}{url}"
|
127 |
+
response = requests.get(full_url, timeout=60)
|
128 |
+
response.raise_for_status() # Raise an exception for bad status codes
|
129 |
+
break
|
130 |
+
except requests.RequestException as e:
|
131 |
+
if attempt == JINA_RETRY_ATTEMPTS - 1: # Last attempt
|
132 |
+
raise ValueError(
|
133 |
+
f"Failed to fetch URL after {JINA_RETRY_ATTEMPTS} attempts: {e}"
|
134 |
+
) from e
|
135 |
+
time.sleep(JINA_RETRY_DELAY) # Wait for X second before retrying
|
136 |
return response.text
|
137 |
|
138 |
|
139 |
+
def generate_podcast_audio(
|
140 |
+
text: str, speaker: str, language: str, use_advanced_audio: bool, random_voice_number: int
|
141 |
+
) -> str:
|
142 |
+
"""Generate audio for podcast using TTS or advanced audio models."""
|
143 |
if use_advanced_audio:
|
144 |
+
return _use_suno_model(text, speaker, language, random_voice_number)
|
145 |
+
else:
|
146 |
+
return _use_melotts_api(text, speaker, language)
|
|
|
|
|
|
|
|
|
|
|
147 |
|
148 |
|
149 |
+
def _use_suno_model(text: str, speaker: str, language: str, random_voice_number: int) -> str:
|
150 |
+
"""Generate advanced audio using Bark."""
|
151 |
+
audio_array = generate_audio(
|
152 |
+
text,
|
153 |
+
history_prompt=f"v2/{language}_speaker_{random_voice_number if speaker == 'Host (Jane)' else random_voice_number + 1}",
|
154 |
+
)
|
155 |
+
file_path = f"audio_{language}_{speaker}.mp3"
|
156 |
+
write_wav(file_path, SAMPLE_RATE, audio_array)
|
157 |
+
return file_path
|
158 |
+
|
159 |
+
|
160 |
+
def _use_melotts_api(text: str, speaker: str, language: str) -> str:
|
161 |
+
"""Generate audio using TTS model."""
|
162 |
+
accent, speed = _get_melo_tts_params(speaker, language)
|
163 |
+
|
164 |
+
for attempt in range(MELO_RETRY_ATTEMPTS):
|
165 |
+
try:
|
166 |
+
return hf_client.predict(
|
167 |
+
text=text,
|
168 |
+
language=language,
|
169 |
+
speaker=accent,
|
170 |
+
speed=speed,
|
171 |
+
api_name=MELO_API_NAME,
|
172 |
+
)
|
173 |
+
except Exception as e:
|
174 |
+
if attempt == MELO_RETRY_ATTEMPTS - 1: # Last attempt
|
175 |
+
raise # Re-raise the last exception if all attempts fail
|
176 |
+
time.sleep(MELO_RETRY_DELAY) # Wait for X second before retrying
|
177 |
+
|
178 |
+
|
179 |
+
def _get_melo_tts_params(speaker: str, language: str) -> tuple[str, float]:
|
180 |
+
"""Get TTS parameters based on speaker and language."""
|
181 |
+
if speaker == "Guest":
|
182 |
+
accent = "EN-US" if language == "EN" else language
|
183 |
+
speed = 0.9
|
184 |
+
else: # host
|
185 |
+
accent = "EN-Default" if language == "EN" else language
|
186 |
+
speed = (
|
187 |
+
1.1 if language != "EN" else 1
|
188 |
+
) # if the language is not English, try speeding up so it'll sound different from the host
|
189 |
+
# for non-English, there is only one voice
|
190 |
+
return accent, speed
|