File size: 3,572 Bytes
9c20b4e
 
 
 
 
 
 
 
 
f136260
dc06293
9c20b4e
 
 
 
8ddd281
 
 
d336953
dc06293
d336953
9c20b4e
 
 
 
 
8ddd281
 
 
 
9c20b4e
 
9db5d78
9c20b4e
 
 
9db5d78
f17c34f
9c20b4e
 
9db5d78
 
f17c34f
 
 
 
 
 
 
 
 
 
 
9c20b4e
 
 
 
 
 
 
 
 
d336953
9c20b4e
 
 
 
 
 
 
 
 
 
dc06293
 
 
 
 
 
8a1ab06
8ddd281
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
utils.py

Functions:
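- generate_script: Get the dialogue from the LLM and ask it to improve its first draft.
- call_llm: Call the LLM with the given prompt and dialogue format.
- parse_url: Fetch the text content of a URL through the Jina reader endpoint.
- generate_audio: Synthesize the audio for a line of dialogue with Bark and save it to disk.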
"""

import os
import requests
from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError

from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav

# Model identifier passed as `model` on every chat-completion request (see call_llm).
MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
# Prefix that parse_url prepends to a target URL before fetching it.
JINA_URL = "https://r.jina.ai/"

# OpenAI-compatible client pointed at the Fireworks inference endpoint.
# The API key is read from the FIREWORKS_API_KEY environment variable;
# os.getenv returns None when the variable is unset.
client = OpenAI(
    base_url="https://api.fireworks.ai/inference/v1",
    api_key=os.getenv("FIREWORKS_API_KEY"),
)

# hf_client = Client("mrfakename/MeloTTS")

# download and load all models
# NOTE: this call sits at module level, so it runs whenever this module is imported.
preload_models()


def generate_script(system_prompt: str, input_text: str, output_model):
    """Produce a dialogue from the LLM in two passes: draft, then refinement.

    Args:
        system_prompt: Base system prompt sent to the model.
        input_text: Source text supplied as the user message for the draft.
        output_model: Model class providing ``model_validate_json`` (and the
            schema used by ``call_llm``); the reply is parsed into it.

    Returns:
        The refined dialogue, as an instance produced by
        ``output_model.model_validate_json``.

    Raises:
        ValidationError: If the retried draft or the refined reply cannot be
            parsed into ``output_model``.
    """

    def _ask(prompt: str, user_text: str):
        # One round trip: query the model, then parse the first choice.
        reply = call_llm(prompt, user_text, output_model)
        return output_model.model_validate_json(reply.choices[0].message.content)

    # First pass: obtain a draft, retrying once with the parse error appended
    # to the system prompt if the reply does not validate.
    try:
        draft = _ask(system_prompt, input_text)
    except ValidationError as exc:
        parse_failure = f"Failed to parse dialogue JSON: {exc}"
        retry_prompt = f"{system_prompt}\n\nPlease return a VALID JSON object. This was the earlier error: {parse_failure}"
        draft = _ask(retry_prompt, input_text)

    # Second pass: show the model its own draft and ask for an improved version.
    refine_prompt = f"{system_prompt}\n\nHere is the first draft of the dialogue you provided:\n\n{draft}."
    return _ask(refine_prompt, "Please improve the dialogue.")


def call_llm(system_prompt: str, text: str, dialogue_format):
    """Send a system/user prompt pair to the chat model and return its response.

    Args:
        system_prompt: Content of the system message.
        text: Content of the user message.
        dialogue_format: Class whose ``model_json_schema()`` supplies the JSON
            schema attached to the requested response format.

    Returns:
        The completion object returned by ``client.chat.completions.create``.
    """
    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": text},
    ]
    # Request a JSON object reply and pass the target schema along with it.
    json_reply_spec = {
        "type": "json_object",
        "schema": dialogue_format.model_json_schema(),
    }
    return client.chat.completions.create(
        messages=chat_messages,
        model=MODEL_ID,
        max_tokens=16_384,
        temperature=0.1,
        response_format=json_reply_spec,
    )


def parse_url(url: str) -> str:
    """Fetch the text content of ``url`` through the Jina reader endpoint.

    Args:
        url: Address of the page to read; it is appended to ``JINA_URL``.

    Returns:
        The response body as text.

    Raises:
        requests.HTTPError: If the endpoint answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or when the request
            exceeds the 60-second timeout.
    """
    full_url = f"{JINA_URL}{url}"
    response = requests.get(full_url, timeout=60)
    # Fail loudly on error statuses so an error page is never returned (and
    # later fed to the LLM) as if it were the page content.
    response.raise_for_status()
    return response.text


def generate_audio(text: str, speaker: str, language: str) -> str:
    """Synthesize one line of dialogue with Bark and save it as a WAV file.

    Args:
        text: The text to speak.
        speaker: Speaker label; ``"Host (Jane)"`` selects Bark speaker preset
            1, any other value selects preset 3.
        language: Language code used to pick the Bark voice preset
            (e.g. ``"en"`` -> ``"v2/en_speaker_1"``).

    Returns:
        Path of the audio file written (relative to the working directory).
    """
    # This function shares its name with Bark's ``generate_audio``. Once this
    # ``def`` has run, the module-level name refers to this wrapper, so calling
    # ``generate_audio(...)`` here would call ourselves and fail on the
    # unexpected ``history_prompt`` argument. Import Bark's synthesizer under
    # an alias instead.
    from bark import generate_audio as bark_generate_audio

    speaker_index = "1" if speaker == "Host (Jane)" else "3"
    # Bark's bundled voice presets are named with lowercase language codes
    # (e.g. "v2/en_speaker_1"), so normalize the case for the preset only.
    voice_preset = f"v2/{language.lower()}_speaker_{speaker_index}"
    audio_array = bark_generate_audio(text, history_prompt=voice_preset)

    # scipy's writer always emits WAV data, so name the file accordingly
    # rather than giving WAV content an ".mp3" extension.
    file_path = f"audio_{language}_{speaker}.wav"

    # save audio to disk
    write_wav(file_path, SAMPLE_RATE, audio_array)

    return file_path


    # """Get the audio from the TTS model from HF Spaces and adjust pitch if necessary."""
    # if speaker == "Guest":
    #     accent = "EN-US" if language == "EN" else language
    #     speed = 0.9
    # else:  # host
    #     accent = "EN-Default" if language == "EN" else language
    #     speed = 1
    # if language != "EN" and speaker != "Guest":
    #     speed = 1.1

    # # Generate audio
    # result = hf_client.predict(
    #     text=text,
    #     language=language,
    #     speaker=accent,
    #     speed=speed,
    #     api_name="/synthesize",
    # )
    # return result