File size: 3,784 Bytes
1a6d10d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import queue
import threading

import gradio as gr
from dia.model import Dia
from huggingface_hub import InferenceClient

# Topic used for every generated podcast; hardcoded, not chosen in the UI.
PODCAST_SUBJECT = "The future of AI and its impact on society"

# Remote chat model (reached through the "together" provider) that writes the script.
client = InferenceClient("Qwen/Qwen2.5-Coder-32B-Instruct", provider="together")
# Dia text-to-speech model, loaded once at import time with float16 compute.
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

# Hand-off between the audio worker thread and the streaming generator: the worker
# (process_audio_chunks) puts synthesized chunks here, followed by a None sentinel.
# NOTE(review): module-level, so every request/session shares this one queue.
audio_queue = queue.Queue()
# Set to ask the worker to stop (e.g. by the Stop button); the worker checks it
# before synthesizing each chunk, and a new stream clears it before starting.
stop_signal = threading.Event()


def generate_podcast_text(subject):
    """Ask the chat model to write a two-host podcast script about *subject*.

    The prompt asks for dialogue tagged ``[S1]`` (male host) and ``[S2]``
    (female host), in the same format as the example lines, about two
    minutes long.

    Args:
        subject: Topic of the podcast, interpolated into the prompt.

    Returns:
        The message content of the first choice returned by
        ``client.chat_completion`` (completion capped at 1000 tokens).
    """
    instructions = f"""Generate a podcast told by 2 hosts about {subject}.
The podcast should be an insightful discussion, with some amount of playful banter.
Separate dialog as follows using [S1] for the male host and [S2] for the female host, for instance:
[S1] Hello, how are you?
[S2] I'm good, thank you. How are you?
[S1] I'm good, thank you. (laughs)
[S2] Great.
Now go on, make 2 minutes of podcast.
"""
    messages = [{"role": "user", "content": instructions}]
    completion = client.chat_completion(messages, max_tokens=1000)
    first_choice = completion.choices[0]
    return first_choice.message.content


def split_podcast_into_chunks(podcast_text, chunk_size=10):
    """Group the lines of *podcast_text* into chunks of at most *chunk_size* lines.

    The whole text is stripped of leading/trailing whitespace, split on newlines,
    and consecutive runs of ``chunk_size`` lines are re-joined with newlines.
    Chunks consisting only of whitespace are dropped, so empty input yields
    ``[]`` rather than ``[""]`` -- each chunk is handed to the TTS model, and an
    empty one is wasted (or failing) work.

    Args:
        podcast_text: Script text, one dialogue turn per line.
        chunk_size: Maximum number of lines per chunk; must be at least 1.

    Returns:
        List of chunk strings, in script order.

    Raises:
        ValueError: If *chunk_size* is less than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    lines = podcast_text.strip().split("\n")
    chunks = ("\n".join(lines[start : start + chunk_size]) for start in range(0, len(lines), chunk_size))
    # Skip whitespace-only chunks (empty input, or long runs of blank lines).
    return [chunk for chunk in chunks if chunk.strip()]


def process_audio_chunks(podcast_text):
    """Synthesize *podcast_text* chunk by chunk and feed the audio to ``audio_queue``.

    Intended to run on a background thread. Each chunk from
    ``split_podcast_into_chunks`` is passed to ``model.generate`` and the result
    is put on ``audio_queue``. ``stop_signal`` is checked before each chunk, so a
    chunk that is already being synthesized runs to completion.

    A ``None`` end-of-stream sentinel is always put on the queue when this
    function exits -- including when splitting or ``model.generate`` raises --
    so a consumer blocking on ``audio_queue.get()`` is always released. The
    exception itself still propagates afterwards (on a thread it is reported by
    ``threading.excepthook``).

    Args:
        podcast_text: Script text to synthesize.
    """
    try:
        for chunk in split_podcast_into_chunks(podcast_text):
            if stop_signal.is_set():
                break

            audio_chunk = model.generate(chunk, use_torch_compile=True, verbose=False)
            audio_queue.put(audio_chunk)
    finally:
        # Must be sent even on failure; without it the consumer waits forever.
        audio_queue.put(None)


def stream_audio_generator(podcast_text):
    """Stream synthesized podcast audio to a streaming ``gr.Audio`` output.

    Runs ``process_audio_chunks`` on a background thread and yields every item it
    puts on the shared ``audio_queue`` as ``(sample_rate, chunk)``. Streaming ends
    at the ``None`` sentinel, or when the worker thread has exited and the
    queue is empty.

    Args:
        podcast_text: Script text (the contents of the script textbox). If it is
            empty or whitespace-only, nothing is yielded.

    Yields:
        ``(sample_rate, chunk)`` tuples, one per synthesized chunk.
    """
    if not podcast_text or not podcast_text.strip():
        # Nothing to synthesize; don't start a worker for an empty script.
        return

    # Dia synthesizes 44.1 kHz audio; the previous 22050 Hz label made playback
    # run at half speed.
    sample_rate = 44100
    worker_name = "podcast-audio-worker"

    # audio_queue and stop_signal are module-level and shared by every run.
    # A worker from an abandoned run (e.g. client disconnected mid-stream) may
    # still be alive: ask it to stop, wait for it to finish its current chunk,
    # then discard whatever it queued so stale audio -- or a stale None
    # sentinel -- cannot leak into this run.
    stop_signal.set()
    for thread in threading.enumerate():
        if thread.name == worker_name:
            thread.join()
    while True:
        try:
            audio_queue.get_nowait()
        except queue.Empty:
            break
    stop_signal.clear()

    # Start audio generation in a separate thread
    gen_thread = threading.Thread(target=process_audio_chunks, args=(podcast_text,), name=worker_name)
    gen_thread.start()

    try:
        while True:
            try:
                chunk = audio_queue.get(timeout=0.5)
            except queue.Empty:
                # If the worker died without sending the sentinel (e.g. the model
                # raised) and nothing is left to read, finish instead of hanging.
                if not gen_thread.is_alive() and audio_queue.empty():
                    break
                continue

            # None signals end of generation
            if chunk is None:
                break

            yield (sample_rate, chunk)
    finally:
        # Also reached when the consumer closes the generator early: tell the
        # worker to stop after its current chunk instead of running to the end.
        stop_signal.set()


def stop_generation():
    """Ask the background audio worker to stop and report it in the status box.

    Only sets the shared ``stop_signal`` event; the worker notices it before
    starting its next chunk.

    Returns:
        The status message shown to the user.
    """
    stop_signal.set()
    status_message = "Generation stopped"
    return status_message


def generate_podcast():
    """Button handler: return a freshly generated script about ``PODCAST_SUBJECT``."""
    return generate_podcast_text(PODCAST_SUBJECT)


# Gradio UI. Components render in the order they are created below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NotebookLM Podcast Generator")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"## Current Topic: {PODCAST_SUBJECT}")
            gr.Markdown("This app generates a podcast discussion between two hosts about the specified topic.")

            generate_btn = gr.Button("Generate Podcast Script", variant="primary")
            # Holds the generated script; also the input to audio synthesis below.
            podcast_output = gr.Textbox(label="Generated Podcast Script", lines=15)

            gr.Markdown("## Audio Preview")
            gr.Markdown("Click below to hear the podcast with realistic voices:")

            with gr.Row():
                start_audio_btn = gr.Button("▶️ Generate Podcast", variant="secondary")
                stop_btn = gr.Button("⏹️ Stop", variant="stop")

            # Streaming output fed by the (sample_rate, chunk) tuples that
            # stream_audio_generator yields.
            audio_output = gr.Audio(label="Podcast Audio", streaming=True)
            # Only written by the Stop button handler.
            status_text = gr.Textbox(label="Status", visible=True)

    # Step 1: fill the textbox with an LLM-generated script.
    generate_btn.click(fn=generate_podcast, outputs=podcast_output)

    # Step 2: synthesize whatever is currently in the textbox and stream it.
    start_audio_btn.click(fn=stream_audio_generator, inputs=podcast_output, outputs=audio_output)
    # NOTE(review): this only sets stop_signal; it does not cancel the streaming
    # event (no `cancels=`), so the worker stops once it checks the event
    # before its next chunk.
    stop_btn.click(fn=stop_generation, outputs=status_text)

if __name__ == "__main__":
    # queue() turns on Gradio's request queue, which streaming (generator)
    # handlers use; it may already be on by default depending on the Gradio
    # version -- confirm for the pinned version.
    demo.queue().launch()