gabriel chua commited on
Commit
ab25593
·
unverified ·
1 Parent(s): 5deb312

Chore: Clean up code (#2)

Browse files
Files changed (5) hide show
  1. app.py +87 -151
  2. constants.py +166 -0
  3. prompts.py +13 -0
  4. schema.py +34 -0
  5. utils.py +148 -77
app.py CHANGED
@@ -8,69 +8,48 @@ import os
8
  import time
9
  from pathlib import Path
10
  from tempfile import NamedTemporaryFile
11
- from typing import List, Literal, Tuple, Optional
12
 
13
  # Third-party imports
14
  import gradio as gr
 
15
  from loguru import logger
16
- from pydantic import BaseModel, Field
17
  from pypdf import PdfReader
18
  from pydub import AudioSegment
19
 
20
  # Local imports
21
- from prompts import SYSTEM_PROMPT
22
- from utils import generate_script, generate_podcast_audio, parse_url
23
-
24
-
25
- class DialogueItem(BaseModel):
26
- """A single dialogue item."""
27
-
28
- speaker: Literal["Host (Jane)", "Guest"]
29
- text: str
30
-
31
-
32
- class ShortDialogue(BaseModel):
33
- """The dialogue between the host and guest."""
34
-
35
- scratchpad: str
36
- name_of_guest: str
37
- dialogue: List[DialogueItem] = Field(..., description="A list of dialogue items, typically between 5 to 9 items")
38
-
39
-
40
- class MediumDialogue(BaseModel):
41
- """The dialogue between the host and guest."""
42
-
43
- scratchpad: str
44
- name_of_guest: str
45
- dialogue: List[DialogueItem] = Field(..., description="A list of dialogue items, typically between 8 to 13 items")
46
-
47
-
48
- LANGUAGE_MAPPING = {
49
- "English": "en",
50
- "Chinese": "zh",
51
- "French": "fr",
52
- "German": "de",
53
- "Hindi": "hi",
54
- "Italian": "it",
55
- "Japanese": "ja",
56
- "Korean": "ko",
57
- "Polish": "pl",
58
- "Portuguese": "pt",
59
- "Russian": "ru",
60
- "Spanish": "es",
61
- "Turkish": "tr"
62
- }
63
-
64
- MELO_TTS_LANGUAGE_MAPPING = {
65
- "en": "EN",
66
- "es": "ES",
67
- "fr": "FR",
68
- "zh": "ZJ",
69
- "ja": "JP",
70
- "ko": "KR",
71
- }
72
-
73
-
74
 
75
 
76
  def generate_podcast(
@@ -84,32 +63,30 @@ def generate_podcast(
84
  ) -> Tuple[str, str]:
85
  """Generate the audio and transcript from the PDFs and/or URL."""
86
 
87
-
88
-
89
  text = ""
90
 
91
- # Check if the selected language is supported by MeloTTS when not using advanced audio
92
- if not use_advanced_audio and language in ['German', 'Hindi', 'Italian', 'Polish', 'Portuguese', 'Russian', 'Turkish']:
93
- raise gr.Error(f"The selected language '{language}' is not supported without advanced audio generation. Please enable advanced audio generation or choose a supported language.")
 
 
94
 
95
  # Check if at least one input is provided
96
  if not files and not url:
97
- raise gr.Error("Please provide at least one PDF file or a URL.")
98
 
99
  # Process PDFs if any
100
  if files:
101
  for file in files:
102
  if not file.lower().endswith(".pdf"):
103
- raise gr.Error(
104
- f"File {file} is not a PDF. Please upload only PDF files."
105
- )
106
 
107
  try:
108
  with Path(file).open("rb") as f:
109
  reader = PdfReader(f)
110
  text += "\n\n".join([page.extract_text() for page in reader.pages])
111
  except Exception as e:
112
- raise gr.Error(f"Error reading the PDF file {file}: {str(e)}")
113
 
114
  # Process URL if provided
115
  if url:
@@ -120,34 +97,27 @@ def generate_podcast(
120
  raise gr.Error(str(e))
121
 
122
  # Check total character count
123
- if len(text) > 100000:
124
- raise gr.Error(
125
- "The total content is too long. Please ensure the combined text from PDFs and URL is fewer than ~100,000 characters."
126
- )
127
-
128
 
129
  # Modify the system prompt based on the user input
130
  modified_system_prompt = SYSTEM_PROMPT
 
131
  if question:
132
- modified_system_prompt += f"\n\PLEASE ANSWER THE FOLLOWING QN: {question}"
133
  if tone:
134
- modified_system_prompt += f"\n\nTONE: The tone of the podcast should be {tone}."
135
  if length:
136
- length_instructions = {
137
- "Short (1-2 min)": "Keep the podcast brief, around 1-2 minutes long.",
138
- "Medium (3-5 min)": "Aim for a moderate length, about 3-5 minutes.",
139
- }
140
- modified_system_prompt += f"\n\nLENGTH: {length_instructions[length]}"
141
  if language:
142
- modified_system_prompt += (
143
- f"\n\nOUTPUT LANGUAGE <IMPORTANT>: The the podcast should be {language}."
144
- )
145
 
146
  # Call the LLM
147
  if length == "Short (1-2 min)":
148
  llm_output = generate_script(modified_system_prompt, text, ShortDialogue)
149
  else:
150
  llm_output = generate_script(modified_system_prompt, text, MediumDialogue)
 
151
  logger.info(f"Generated dialogue: {llm_output}")
152
 
153
  # Process the dialogue
@@ -164,14 +134,14 @@ def generate_podcast(
164
  transcript += speaker + "\n\n"
165
  total_characters += len(line.text)
166
 
167
- language_for_tts = LANGUAGE_MAPPING[language]
168
 
169
  if not use_advanced_audio:
170
  language_for_tts = MELO_TTS_LANGUAGE_MAPPING[language_for_tts]
171
 
172
  # Get audio file path
173
  audio_file_path = generate_podcast_audio(
174
- line.text, line.speaker, language_for_tts, use_advanced_audio
175
  )
176
  # Read the audio file into an AudioSegment
177
  audio_segment = AudioSegment.from_file(audio_file_path)
@@ -181,7 +151,7 @@ def generate_podcast(
181
  combined_audio = sum(audio_segments)
182
 
183
  # Export the combined audio to a temporary file
184
- temporary_directory = "./gradio_cached_examples/tmp/"
185
  os.makedirs(temporary_directory, exist_ok=True)
186
 
187
  temporary_file = NamedTemporaryFile(
@@ -193,7 +163,10 @@ def generate_podcast(
193
 
194
  # Delete any files in the temp directory that end with .mp3 and are over a day old
195
  for file in glob.glob(f"{temporary_directory}*.mp3"):
196
- if os.path.isfile(file) and time.time() - os.path.getmtime(file) > 24 * 60 * 60:
 
 
 
197
  os.remove(file)
198
 
199
  logger.info(f"Generated {total_characters} characters of audio")
@@ -202,90 +175,53 @@ def generate_podcast(
202
 
203
 
204
  demo = gr.Interface(
205
- title="Open NotebookLM",
206
- description="""
207
-
208
- <table style="border-collapse: collapse; border: none; padding: 20px;">
209
- <tr style="border: none;">
210
- <td style="border: none; vertical-align: top; padding-right: 30px; padding-left: 30px;">
211
- <img src="https://raw.githubusercontent.com/gabrielchua/daily-ai-papers/main/_includes/icon.png" alt="Open NotebookLM" width="120" style="margin-bottom: 10px;">
212
- </td>
213
- <td style="border: none; vertical-align: top; padding: 10px;">
214
- <p style="margin-bottom: 15px;"><strong>Convert</strong> your PDFs into podcasts with open-source AI models (Llama 3.1 405B and MeloTTS).</p>
215
- <p style="margin-top: 15px;">Note: Only the text content of the PDFs will be processed. Images and tables are not included. The total content should be no more than 100,000 characters due to the context length of Llama 3.1 405B.</p>
216
- </td>
217
- </tr>
218
- </table>
219
- """,
220
  fn=generate_podcast,
221
  inputs=[
222
  gr.File(
223
- label="1. 📄 Upload your PDF(s)", file_types=[".pdf"], file_count="multiple"
 
 
224
  ),
225
  gr.Textbox(
226
- label="2. 🔗 Paste a URL (optional)",
227
- placeholder="Enter a URL to include its content",
228
  ),
229
- gr.Textbox(label="3. 🤔 Do you have a specific question or topic in mind?"),
230
  gr.Dropdown(
231
- choices=["Fun", "Formal"],
232
- label="4. 🎭 Choose the tone",
233
- value="Fun"
234
  ),
235
  gr.Dropdown(
236
- choices=["Short (1-2 min)", "Medium (3-5 min)"],
237
- label="5. ⏱️ Choose the length",
238
- value="Medium (3-5 min)"
239
  ),
240
  gr.Dropdown(
241
- choices=list(LANGUAGE_MAPPING.keys()),
242
- value="English",
243
- label="6. 🌐 Choose the language"
244
  ),
245
  gr.Checkbox(
246
- label="7. 🔄 Use advanced audio generation? (Experimental)",
247
- value=False
248
- )
249
  ],
250
  outputs=[
251
- gr.Audio(label="Podcast", format="mp3"),
252
- gr.Markdown(label="Transcript"),
 
 
253
  ],
254
- allow_flagging="never",
255
- api_name="generate_podcast",
256
  theme=gr.themes.Soft(),
257
- concurrency_limit=3,
258
- examples=[
259
- [
260
- [str(Path("examples/1310.4546v1.pdf"))],
261
- "",
262
- "Explain this paper to me like I'm 5 years old",
263
- "Fun",
264
- "Short (1-2 min)",
265
- "English",
266
- True
267
- ],
268
- [
269
- [],
270
- "https://en.wikipedia.org/wiki/Hugging_Face",
271
- "How did Hugging Face become so successful?",
272
- "Fun",
273
- "Short (1-2 min)",
274
- "English",
275
- False
276
- ],
277
- [
278
- [],
279
- "https://simple.wikipedia.org/wiki/Taylor_Swift",
280
- "Why is Taylor Swift so popular?",
281
- "Fun",
282
- "Short (1-2 min)",
283
- "English",
284
- False
285
- ],
286
- ],
287
- cache_examples=True,
288
  )
289
 
290
  if __name__ == "__main__":
291
- demo.launch(show_api=True)
 
8
  import time
9
  from pathlib import Path
10
  from tempfile import NamedTemporaryFile
11
+ from typing import List, Tuple, Optional
12
 
13
  # Third-party imports
14
  import gradio as gr
15
+ import random
16
  from loguru import logger
 
17
  from pypdf import PdfReader
18
  from pydub import AudioSegment
19
 
20
  # Local imports
21
+ from constants import (
22
+ APP_TITLE,
23
+ CHARACTER_LIMIT,
24
+ ERROR_MESSAGE_NOT_PDF,
25
+ ERROR_MESSAGE_NO_INPUT,
26
+ ERROR_MESSAGE_NOT_SUPPORTED_IN_MELO_TTS,
27
+ ERROR_MESSAGE_READING_PDF,
28
+ ERROR_MESSAGE_TOO_LONG,
29
+ GRADIO_CACHE_DIR,
30
+ GRADIO_CLEAR_CACHE_OLDER_THAN,
31
+ MELO_TTS_LANGUAGE_MAPPING,
32
+ NOT_SUPPORTED_IN_MELO_TTS,
33
+ SUNO_LANGUAGE_MAPPING,
34
+ UI_ALLOW_FLAGGING,
35
+ UI_API_NAME,
36
+ UI_CACHE_EXAMPLES,
37
+ UI_CONCURRENCY_LIMIT,
38
+ UI_DESCRIPTION,
39
+ UI_EXAMPLES,
40
+ UI_INPUTS,
41
+ UI_OUTPUTS,
42
+ UI_SHOW_API,
43
+ )
44
+ from prompts import (
45
+ LANGUAGE_MODIFIER,
46
+ LENGTH_MODIFIERS,
47
+ QUESTION_MODIFIER,
48
+ SYSTEM_PROMPT,
49
+ TONE_MODIFIER,
50
+ )
51
+ from schema import ShortDialogue, MediumDialogue
52
+ from utils import generate_podcast_audio, generate_script, parse_url
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
  def generate_podcast(
 
63
  ) -> Tuple[str, str]:
64
  """Generate the audio and transcript from the PDFs and/or URL."""
65
 
 
 
66
  text = ""
67
 
68
+ # Choose random number from 0 to 9
69
+ random_voice_number = random.randint(0, 9) # this is for suno model
70
+
71
+ if not use_advanced_audio and language in NOT_SUPPORTED_IN_MELO_TTS:
72
+ raise gr.Error(ERROR_MESSAGE_NOT_SUPPORTED_IN_MELO_TTS)
73
 
74
  # Check if at least one input is provided
75
  if not files and not url:
76
+ raise gr.Error(ERROR_MESSAGE_NO_INPUT)
77
 
78
  # Process PDFs if any
79
  if files:
80
  for file in files:
81
  if not file.lower().endswith(".pdf"):
82
+ raise gr.Error(ERROR_MESSAGE_NOT_PDF)
 
 
83
 
84
  try:
85
  with Path(file).open("rb") as f:
86
  reader = PdfReader(f)
87
  text += "\n\n".join([page.extract_text() for page in reader.pages])
88
  except Exception as e:
89
+ raise gr.Error(f"{ERROR_MESSAGE_READING_PDF}: {str(e)}")
90
 
91
  # Process URL if provided
92
  if url:
 
97
  raise gr.Error(str(e))
98
 
99
  # Check total character count
100
+ if len(text) > CHARACTER_LIMIT:
101
+ raise gr.Error(ERROR_MESSAGE_TOO_LONG)
 
 
 
102
 
103
  # Modify the system prompt based on the user input
104
  modified_system_prompt = SYSTEM_PROMPT
105
+
106
  if question:
107
+ modified_system_prompt += f"\n\n{QUESTION_MODIFIER} {question}"
108
  if tone:
109
+ modified_system_prompt += f"\n\n{TONE_MODIFIER} {tone}."
110
  if length:
111
+ modified_system_prompt += f"\n\n{LENGTH_MODIFIERS[length]}"
 
 
 
 
112
  if language:
113
+ modified_system_prompt += f"\n\n{LANGUAGE_MODIFIER} {language}."
 
 
114
 
115
  # Call the LLM
116
  if length == "Short (1-2 min)":
117
  llm_output = generate_script(modified_system_prompt, text, ShortDialogue)
118
  else:
119
  llm_output = generate_script(modified_system_prompt, text, MediumDialogue)
120
+
121
  logger.info(f"Generated dialogue: {llm_output}")
122
 
123
  # Process the dialogue
 
134
  transcript += speaker + "\n\n"
135
  total_characters += len(line.text)
136
 
137
+ language_for_tts = SUNO_LANGUAGE_MAPPING[language]
138
 
139
  if not use_advanced_audio:
140
  language_for_tts = MELO_TTS_LANGUAGE_MAPPING[language_for_tts]
141
 
142
  # Get audio file path
143
  audio_file_path = generate_podcast_audio(
144
+ line.text, line.speaker, language_for_tts, use_advanced_audio, random_voice_number
145
  )
146
  # Read the audio file into an AudioSegment
147
  audio_segment = AudioSegment.from_file(audio_file_path)
 
151
  combined_audio = sum(audio_segments)
152
 
153
  # Export the combined audio to a temporary file
154
+ temporary_directory = GRADIO_CACHE_DIR
155
  os.makedirs(temporary_directory, exist_ok=True)
156
 
157
  temporary_file = NamedTemporaryFile(
 
163
 
164
  # Delete any files in the temp directory that end with .mp3 and are over a day old
165
  for file in glob.glob(f"{temporary_directory}*.mp3"):
166
+ if (
167
+ os.path.isfile(file)
168
+ and time.time() - os.path.getmtime(file) > GRADIO_CLEAR_CACHE_OLDER_THAN
169
+ ):
170
  os.remove(file)
171
 
172
  logger.info(f"Generated {total_characters} characters of audio")
 
175
 
176
 
177
  demo = gr.Interface(
178
+ title=APP_TITLE,
179
+ description=UI_DESCRIPTION,
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  fn=generate_podcast,
181
  inputs=[
182
  gr.File(
183
+ label=UI_INPUTS["file_upload"]["label"], # Step 1: File upload
184
+ file_types=UI_INPUTS["file_upload"]["file_types"],
185
+ file_count=UI_INPUTS["file_upload"]["file_count"],
186
  ),
187
  gr.Textbox(
188
+ label=UI_INPUTS["url"]["label"], # Step 2: URL
189
+ placeholder=UI_INPUTS["url"]["placeholder"],
190
  ),
191
+ gr.Textbox(label=UI_INPUTS["question"]["label"]), # Step 3: Question
192
  gr.Dropdown(
193
+ label=UI_INPUTS["tone"]["label"], # Step 4: Tone
194
+ choices=UI_INPUTS["tone"]["choices"],
195
+ value=UI_INPUTS["tone"]["value"],
196
  ),
197
  gr.Dropdown(
198
+ label=UI_INPUTS["length"]["label"], # Step 5: Length
199
+ choices=UI_INPUTS["length"]["choices"],
200
+ value=UI_INPUTS["length"]["value"],
201
  ),
202
  gr.Dropdown(
203
+ choices=UI_INPUTS["language"]["choices"], # Step 6: Language
204
+ value=UI_INPUTS["language"]["value"],
205
+ label=UI_INPUTS["language"]["label"],
206
  ),
207
  gr.Checkbox(
208
+ label=UI_INPUTS["advanced_audio"]["label"],
209
+ value=UI_INPUTS["advanced_audio"]["value"],
210
+ ),
211
  ],
212
  outputs=[
213
+ gr.Audio(
214
+ label=UI_OUTPUTS["audio"]["label"], format=UI_OUTPUTS["audio"]["format"]
215
+ ),
216
+ gr.Markdown(label=UI_OUTPUTS["transcript"]["label"]),
217
  ],
218
+ allow_flagging=UI_ALLOW_FLAGGING,
219
+ api_name=UI_API_NAME,
220
  theme=gr.themes.Soft(),
221
+ concurrency_limit=UI_CONCURRENCY_LIMIT,
222
+ # examples=UI_EXAMPLES,
223
+ # cache_examples=UI_CACHE_EXAMPLES,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  )
225
 
226
  if __name__ == "__main__":
227
+ demo.launch(show_api=UI_SHOW_API)
constants.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ constants.py
3
+ """
4
+
5
+ import os
6
+
7
+ from pathlib import Path
8
+
9
+ # Key constants
10
+ APP_TITLE = "Open NotebookLM"
11
+ CHARACTER_LIMIT = 100_000
12
+
13
+ # Gradio-related constants
14
+ GRADIO_CACHE_DIR = "./gradio_cached_examples/tmp/"
15
+ GRADIO_CLEAR_CACHE_OLDER_THAN = 1 * 24 * 60 * 60 # 1 day
16
+
17
+ # Error messages-related constants
18
+ ERROR_MESSAGE_NO_INPUT = "Please provide at least one PDF file or a URL."
19
+ ERROR_MESSAGE_NOT_PDF = "The provided file is not a PDF. Please upload only PDF files."
20
+ ERROR_MESSAGE_NOT_SUPPORTED_IN_MELO_TTS = "The selected language is not supported without advanced audio generation. Please enable advanced audio generation or choose a supported language."
21
+ ERROR_MESSAGE_READING_PDF = "Error reading the PDF file"
22
+ ERROR_MESSAGE_TOO_LONG = "The total content is too long. Please ensure the combined text from PDFs and URL is fewer than {CHARACTER_LIMIT} characters."
23
+
24
+ # Fireworks API-related constants
25
+ FIREWORKS_API_KEY = os.getenv("FIREWORKS_API_KEY")
26
+ FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
27
+ FIREWORKS_MAX_TOKENS = 16_384
28
+ FIREWORKS_MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
29
+ FIREWORKS_TEMPERATURE = 0.1
30
+ FIREWORKS_JSON_RETRY_ATTEMPTS = 3
31
+
32
+ # MeloTTS
33
+ MELO_API_NAME = "/synthesize"
34
+ MELO_TTS_SPACES_ID = "mrfakename/MeloTTS"
35
+ MELO_RETRY_ATTEMPTS = 3
36
+ MELO_RETRY_DELAY = 5 # in seconds
37
+
38
+ MELO_TTS_LANGUAGE_MAPPING = {
39
+ "en": "EN",
40
+ "es": "ES",
41
+ "fr": "FR",
42
+ "zh": "ZJ",
43
+ "ja": "JP",
44
+ "ko": "KR",
45
+ }
46
+
47
+
48
+ # Suno related constants
49
+ SUNO_LANGUAGE_MAPPING = {
50
+ "English": "en",
51
+ "Chinese": "zh",
52
+ "French": "fr",
53
+ "German": "de",
54
+ "Hindi": "hi",
55
+ "Italian": "it",
56
+ "Japanese": "ja",
57
+ "Korean": "ko",
58
+ "Polish": "pl",
59
+ "Portuguese": "pt",
60
+ "Russian": "ru",
61
+ "Spanish": "es",
62
+ "Turkish": "tr",
63
+ }
64
+
65
+ # General audio-related constants
66
+ NOT_SUPPORTED_IN_MELO_TTS = list(
67
+ set(SUNO_LANGUAGE_MAPPING.values()) - set(MELO_TTS_LANGUAGE_MAPPING.keys())
68
+ )
69
+ NOT_SUPPORTED_IN_MELO_TTS = [
70
+ key for key, id in SUNO_LANGUAGE_MAPPING.items() if id in NOT_SUPPORTED_IN_MELO_TTS
71
+ ]
72
+
73
+ # Jina Reader-related constants
74
+ JINA_READER_URL = "https://r.jina.ai/"
75
+ JINA_RETRY_ATTEMPTS = 3
76
+ JINA_RETRY_DELAY = 5 # in seconds
77
+
78
+ # UI-related constants
79
+ UI_DESCRIPTION = """
80
+ <table style="border-collapse: collapse; border: none; padding: 20px;">
81
+ <tr style="border: none;">
82
+ <td style="border: none; vertical-align: top; padding-right: 30px; padding-left: 30px;">
83
+ <img src="https://raw.githubusercontent.com/gabrielchua/daily-ai-papers/main/_includes/icon.png" alt="Open NotebookLM" width="120" style="margin-bottom: 10px;">
84
+ </td>
85
+ <td style="border: none; vertical-align: top; padding: 10px;">
86
+ <p style="margin-bottom: 15px;">Convert your PDFs into podcasts with open-source AI models (<a href="https://huggingface.co/meta-llama/Llama-3.1-405B">Llama 3.1 405B</a>, <a href="https://huggingface.co/myshell-ai/MeloTTS-English">MeloTTS</a>, <a href="https://huggingface.co/suno/bark">Bark</a>).</p>
87
+ <p style="margin-top: 15px;">Note: Only the text content of the PDFs will be processed. Images and tables are not included. The total content should be no more than 100,000 characters due to the context length of Llama 3.1 405B.</p>
88
+ </td>
89
+ </tr>
90
+ </table>
91
+ """
92
+ UI_AVAILABLE_LANGUAGES = list(set(SUNO_LANGUAGE_MAPPING.keys()))
93
+ UI_INPUTS = {
94
+ "file_upload": {
95
+ "label": "1. 📄 Upload your PDF(s)",
96
+ "file_types": [".pdf"],
97
+ "file_count": "multiple",
98
+ },
99
+ "url": {
100
+ "label": "2. 🔗 Paste a URL (optional)",
101
+ "placeholder": "Enter a URL to include its content",
102
+ },
103
+ "question": {
104
+ "label": "3. 🤔 Do you have a specific question or topic in mind?",
105
+ "placeholder": "Enter a question or topic",
106
+ },
107
+ "tone": {
108
+ "label": "4. 🎭 Choose the tone",
109
+ "choices": ["Fun", "Formal"],
110
+ "value": "Fun",
111
+ },
112
+ "length": {
113
+ "label": "5. ⏱️ Choose the length",
114
+ "choices": ["Short (1-2 min)", "Medium (3-5 min)"],
115
+ "value": "Medium (3-5 min)",
116
+ },
117
+ "language": {
118
+ "label": "6. 🌐 Choose the language",
119
+ "choices": UI_AVAILABLE_LANGUAGES,
120
+ "value": "English",
121
+ },
122
+ "advanced_audio": {
123
+ "label": "7. 🔄 Use advanced audio generation? (Experimental)",
124
+ "value": False,
125
+ },
126
+ }
127
+ UI_OUTPUTS = {
128
+ "audio": {"label": "🔊 Podcast", "format": "mp3"},
129
+ "transcript": {
130
+ "label": "📜 Transcript",
131
+ },
132
+ }
133
+ UI_API_NAME = "generate_podcast"
134
+ UI_ALLOW_FLAGGING = "never"
135
+ UI_CONCURRENCY_LIMIT = 3
136
+ UI_EXAMPLES = [
137
+ [
138
+ [str(Path("examples/1310.4546v1.pdf"))],
139
+ "",
140
+ "Explain this paper to me like I'm 5 years old",
141
+ "Fun",
142
+ "Short (1-2 min)",
143
+ "English",
144
+ True,
145
+ ],
146
+ [
147
+ [],
148
+ "https://en.wikipedia.org/wiki/Hugging_Face",
149
+ "How did Hugging Face become so successful?",
150
+ "Fun",
151
+ "Short (1-2 min)",
152
+ "English",
153
+ False,
154
+ ],
155
+ [
156
+ [],
157
+ "https://simple.wikipedia.org/wiki/Taylor_Swift",
158
+ "Why is Taylor Swift so popular?",
159
+ "Fun",
160
+ "Short (1-2 min)",
161
+ "English",
162
+ False,
163
+ ],
164
+ ]
165
+ UI_CACHE_EXAMPLES = True
166
+ UI_SHOW_API = True
prompts.py CHANGED
@@ -51,5 +51,18 @@ You are a world-class podcast producer tasked with transforming the provided inp
51
  - Include brief "breather" moments for listeners to absorb complex information
52
  - End on a high note, perhaps with a thought-provoking question or a call-to-action for listeners
53
 
 
 
54
  Remember: Always reply in valid JSON format, without code blocks. Begin directly with the JSON output.
55
  """
 
 
 
 
 
 
 
 
 
 
 
 
51
  - Include brief "breather" moments for listeners to absorb complex information
52
  - End on a high note, perhaps with a thought-provoking question or a call-to-action for listeners
53
 
54
+ IMPORTANT RULE: Each line of dialogue should be no more than 100 characters (e.g., can finish within 5-8 seconds)
55
+
56
  Remember: Always reply in valid JSON format, without code blocks. Begin directly with the JSON output.
57
  """
58
+
59
+ QUESTION_MODIFIER = "PLEASE ANSWER THE FOLLOWING QN:"
60
+
61
+ TONE_MODIFIER = "TONE: The tone of the podcast should be"
62
+
63
+ LANGUAGE_MODIFIER = "OUTPUT LANGUAGE <IMPORTANT>: The the podcast should be"
64
+
65
+ LENGTH_MODIFIERS = {
66
+ "Short (1-2 min)": "Keep the podcast brief, around 1-2 minutes long.",
67
+ "Medium (3-5 min)": "Aim for a moderate length, about 3-5 minutes.",
68
+ }
schema.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ schema.py
3
+ """
4
+
5
+ from typing import Literal, List
6
+
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
+ class DialogueItem(BaseModel):
11
+ """A single dialogue item."""
12
+
13
+ speaker: Literal["Host (Jane)", "Guest"]
14
+ text: str
15
+
16
+
17
+ class ShortDialogue(BaseModel):
18
+ """The dialogue between the host and guest."""
19
+
20
+ scratchpad: str
21
+ name_of_guest: str
22
+ dialogue: List[DialogueItem] = Field(
23
+ ..., description="A list of dialogue items, typically between 11 to 17 items"
24
+ )
25
+
26
+
27
+ class MediumDialogue(BaseModel):
28
+ """The dialogue between the host and guest."""
29
+
30
+ scratchpad: str
31
+ name_of_guest: str
32
+ dialogue: List[DialogueItem] = Field(
33
+ ..., description="A list of dialogue items, typically between 19 to 29 items"
34
+ )
utils.py CHANGED
@@ -2,68 +2,115 @@
2
  utils.py
3
 
4
  Functions:
5
- - get_script: Get the dialogue from the LLM.
6
  - call_llm: Call the LLM with the given prompt and dialogue format.
7
- - get_audio: Get the audio from the TTS model from HF Spaces.
 
8
  """
9
 
10
- import os
11
- import requests
12
  import time
 
 
 
 
 
13
  from gradio_client import Client
14
  from openai import OpenAI
15
  from pydantic import ValidationError
16
-
17
- from bark import SAMPLE_RATE, generate_audio, preload_models
18
  from scipy.io.wavfile import write as write_wav
19
 
20
- MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
21
- JINA_URL = "https://r.jina.ai/"
22
-
23
- client = OpenAI(
24
- base_url="https://api.fireworks.ai/inference/v1",
25
- api_key=os.getenv("FIREWORKS_API_KEY"),
 
 
 
 
 
 
 
 
 
26
  )
 
27
 
28
- hf_client = Client("mrfakename/MeloTTS")
 
 
29
 
30
- # download and load all models
31
  preload_models()
32
 
33
 
34
- def generate_script(system_prompt: str, input_text: str, output_model):
 
 
 
 
35
  """Get the dialogue from the LLM."""
36
- # Load as python object
37
- try:
38
- response = call_llm(system_prompt, input_text, output_model)
39
- dialogue = output_model.model_validate_json(response.choices[0].message.content)
40
- except ValidationError as e:
41
- error_message = f"Failed to parse dialogue JSON: {e}"
42
- system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
43
- response = call_llm(system_prompt_with_error, input_text, output_model)
44
- dialogue = output_model.model_validate_json(response.choices[0].message.content)
45
-
46
- # Call the LLM again to improve the dialogue
47
- system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{dialogue}."
48
- response = call_llm(
49
- system_prompt_with_dialogue, "Please improve the dialogue.", output_model
50
- )
51
- improved_dialogue = output_model.model_validate_json(
52
- response.choices[0].message.content
53
- )
54
- return improved_dialogue
55
 
56
-
57
- def call_llm(system_prompt: str, text: str, dialogue_format):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  """Call the LLM with the given prompt and dialogue format."""
59
- response = client.chat.completions.create(
60
  messages=[
61
  {"role": "system", "content": system_prompt},
62
  {"role": "user", "content": text},
63
  ],
64
- model=MODEL_ID,
65
- max_tokens=16_384,
66
- temperature=0.1,
67
  response_format={
68
  "type": "json_object",
69
  "schema": dialogue_format.model_json_schema(),
@@ -74,46 +121,70 @@ def call_llm(system_prompt: str, text: str, dialogue_format):
74
 
75
  def parse_url(url: str) -> str:
76
  """Parse the given URL and return the text content."""
77
- full_url = f"{JINA_URL}{url}"
78
- response = requests.get(full_url, timeout=60)
 
 
 
 
 
 
 
 
 
 
79
  return response.text
80
 
81
 
82
- def generate_podcast_audio(text: str, speaker: str, language: str, use_advanced_audio: bool) -> str:
83
-
 
 
84
  if use_advanced_audio:
85
- audio_array = generate_audio(text, history_prompt=f"v2/{language}_speaker_{'1' if speaker == 'Host (Jane)' else '3'}")
86
-
87
- file_path = f"audio_{language}_{speaker}.mp3"
88
-
89
- # save audio to disk
90
- write_wav(file_path, SAMPLE_RATE, audio_array)
91
-
92
- return file_path
93
 
94
 
95
- else:
96
- if speaker == "Guest":
97
- accent = "EN-US" if language == "EN" else language
98
- speed = 0.9
99
- else: # host
100
- accent = "EN-Default" if language == "EN" else language
101
- speed = 1
102
- if language != "EN" and speaker != "Guest":
103
- speed = 1.1
104
-
105
- # Generate audio
106
- for attempt in range(3):
107
- try:
108
- result = hf_client.predict(
109
- text=text,
110
- language=language,
111
- speaker=accent,
112
- speed=speed,
113
- api_name="/synthesize",
114
- )
115
- return result
116
- except Exception as e:
117
- if attempt == 2: # Last attempt
118
- raise # Re-raise the last exception if all attempts fail
119
- time.sleep(1) # Wait for 1 second before retrying
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  utils.py
3
 
4
  Functions:
5
+ - generate_script: Get the dialogue from the LLM.
6
  - call_llm: Call the LLM with the given prompt and dialogue format.
7
+ - parse_url: Parse the given URL and return the text content.
8
+ - generate_podcast_audio: Generate audio for podcast using TTS or advanced audio models.
9
  """
10
 
11
+ # Standard library imports
 
12
  import time
13
+ from typing import Any, Union
14
+
15
+ # Third-party imports
16
+ import requests
17
+ from bark import SAMPLE_RATE, generate_audio, preload_models
18
  from gradio_client import Client
19
  from openai import OpenAI
20
  from pydantic import ValidationError
 
 
21
  from scipy.io.wavfile import write as write_wav
22
 
23
+ # Local imports
24
+ from constants import (
25
+ FIREWORKS_API_KEY,
26
+ FIREWORKS_BASE_URL,
27
+ FIREWORKS_MODEL_ID,
28
+ FIREWORKS_MAX_TOKENS,
29
+ FIREWORKS_TEMPERATURE,
30
+ FIREWORKS_JSON_RETRY_ATTEMPTS,
31
+ MELO_API_NAME,
32
+ MELO_TTS_SPACES_ID,
33
+ MELO_RETRY_ATTEMPTS,
34
+ MELO_RETRY_DELAY,
35
+ JINA_READER_URL,
36
+ JINA_RETRY_ATTEMPTS,
37
+ JINA_RETRY_DELAY,
38
  )
39
+ from schema import ShortDialogue, MediumDialogue
40
 
41
+ # Initialize clients
42
+ fw_client = OpenAI(base_url=FIREWORKS_BASE_URL, api_key=FIREWORKS_API_KEY)
43
+ hf_client = Client(MELO_TTS_SPACES_ID)
44
 
45
+ # Download and load all models for Bark
46
  preload_models()
47
 
48
 
49
+ def generate_script(
50
+ system_prompt: str,
51
+ input_text: str,
52
+ output_model: Union[ShortDialogue, MediumDialogue],
53
+ ) -> Union[ShortDialogue, MediumDialogue]:
54
  """Get the dialogue from the LLM."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ # Call the LLM
57
+ response = call_llm(system_prompt, input_text, output_model)
58
+ response_json = response.choices[0].message.content
59
+
60
+ # Validate the response
61
+ for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
62
+ try:
63
+ first_draft_dialogue = output_model.model_validate_json(response_json)
64
+ break
65
+ except ValidationError as e:
66
+ if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
67
+ raise ValueError(
68
+ f"Failed to parse dialogue JSON after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
69
+ ) from e
70
+ error_message = (
71
+ f"Failed to parse dialogue JSON (attempt {attempt + 1}): {e}"
72
+ )
73
+ # Re-call the LLM with the error message
74
+ system_prompt_with_error = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
75
+ response = call_llm(system_prompt_with_error, input_text, output_model)
76
+ response_json = response.choices[0].message.content
77
+ first_draft_dialogue = output_model.model_validate_json(response_json)
78
+
79
+ # Call the LLM a second time to improve the dialogue
80
+ system_prompt_with_dialogue = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{first_draft_dialogue}."
81
+
82
+ # Validate the response
83
+ for attempt in range(FIREWORKS_JSON_RETRY_ATTEMPTS):
84
+ try:
85
+ response = call_llm(
86
+ system_prompt_with_dialogue,
87
+ "Please improve the dialogue. Make it more natural and engaging.",
88
+ output_model,
89
+ )
90
+ final_dialogue = output_model.model_validate_json(
91
+ response.choices[0].message.content
92
+ )
93
+ break
94
+ except ValidationError as e:
95
+ if attempt == FIREWORKS_JSON_RETRY_ATTEMPTS - 1: # Last attempt
96
+ raise ValueError(
97
+ f"Failed to improve dialogue after {FIREWORKS_JSON_RETRY_ATTEMPTS} attempts: {e}"
98
+ ) from e
99
+ error_message = f"Failed to improve dialogue (attempt {attempt + 1}): {e}"
100
+ system_prompt_with_dialogue += f"\n\nPlease return a VALID JSON object. This was the earlier error: {error_message}"
101
+ return final_dialogue
102
+
103
+
104
+ def call_llm(system_prompt: str, text: str, dialogue_format: Any) -> Any:
105
  """Call the LLM with the given prompt and dialogue format."""
106
+ response = fw_client.chat.completions.create(
107
  messages=[
108
  {"role": "system", "content": system_prompt},
109
  {"role": "user", "content": text},
110
  ],
111
+ model=FIREWORKS_MODEL_ID,
112
+ max_tokens=FIREWORKS_MAX_TOKENS,
113
+ temperature=FIREWORKS_TEMPERATURE,
114
  response_format={
115
  "type": "json_object",
116
  "schema": dialogue_format.model_json_schema(),
 
121
 
122
  def parse_url(url: str) -> str:
123
  """Parse the given URL and return the text content."""
124
+ for attempt in range(JINA_RETRY_ATTEMPTS):
125
+ try:
126
+ full_url = f"{JINA_READER_URL}{url}"
127
+ response = requests.get(full_url, timeout=60)
128
+ response.raise_for_status() # Raise an exception for bad status codes
129
+ break
130
+ except requests.RequestException as e:
131
+ if attempt == JINA_RETRY_ATTEMPTS - 1: # Last attempt
132
+ raise ValueError(
133
+ f"Failed to fetch URL after {JINA_RETRY_ATTEMPTS} attempts: {e}"
134
+ ) from e
135
+ time.sleep(JINA_RETRY_DELAY) # Wait for X second before retrying
136
  return response.text
137
 
138
 
139
+ def generate_podcast_audio(
140
+ text: str, speaker: str, language: str, use_advanced_audio: bool, random_voice_number: int
141
+ ) -> str:
142
+ """Generate audio for podcast using TTS or advanced audio models."""
143
  if use_advanced_audio:
144
+ return _use_suno_model(text, speaker, language, random_voice_number)
145
+ else:
146
+ return _use_melotts_api(text, speaker, language)
 
 
 
 
 
147
 
148
 
149
+ def _use_suno_model(text: str, speaker: str, language: str, random_voice_number: int) -> str:
150
+ """Generate advanced audio using Bark."""
151
+ audio_array = generate_audio(
152
+ text,
153
+ history_prompt=f"v2/{language}_speaker_{random_voice_number if speaker == 'Host (Jane)' else random_voice_number + 1}",
154
+ )
155
+ file_path = f"audio_{language}_{speaker}.mp3"
156
+ write_wav(file_path, SAMPLE_RATE, audio_array)
157
+ return file_path
158
+
159
+
160
+ def _use_melotts_api(text: str, speaker: str, language: str) -> str:
161
+ """Generate audio using TTS model."""
162
+ accent, speed = _get_melo_tts_params(speaker, language)
163
+
164
+ for attempt in range(MELO_RETRY_ATTEMPTS):
165
+ try:
166
+ return hf_client.predict(
167
+ text=text,
168
+ language=language,
169
+ speaker=accent,
170
+ speed=speed,
171
+ api_name=MELO_API_NAME,
172
+ )
173
+ except Exception as e:
174
+ if attempt == MELO_RETRY_ATTEMPTS - 1: # Last attempt
175
+ raise # Re-raise the last exception if all attempts fail
176
+ time.sleep(MELO_RETRY_DELAY) # Wait for X second before retrying
177
+
178
+
179
+ def _get_melo_tts_params(speaker: str, language: str) -> tuple[str, float]:
180
+ """Get TTS parameters based on speaker and language."""
181
+ if speaker == "Guest":
182
+ accent = "EN-US" if language == "EN" else language
183
+ speed = 0.9
184
+ else: # host
185
+ accent = "EN-Default" if language == "EN" else language
186
+ speed = (
187
+ 1.1 if language != "EN" else 1
188
+ ) # if the language is not English, try speeding up so it'll sound different from the host
189
+ # for non-English, there is only one voice
190
+ return accent, speed