|
import os |
|
from groq import Groq |
|
from gtts import gTTS |
|
from pydantic import BaseModel, ValidationError |
|
from typing import List, Literal |
|
import tiktoken |
|
|
|
# Shared Groq API client, created once at import time; the API key comes from
# the GROQ_API_KEY environment variable.
groq_client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# cl100k_base tokenizer used by truncate_text to bound the prompt size.
tokenizer = tiktoken.get_encoding(encoding_name="cl100k_base")
|
|
|
class DialogueItem(BaseModel):
    """A single spoken line in the generated script."""

    # Who speaks this line; pydantic rejects any label other than these two.
    speaker: Literal["Host", "Guest"]
    # The words spoken.
    text: str
|
|
|
class Dialogue(BaseModel):
    """The whole script: the ordered lines validated from the model's JSON reply."""

    # Lines in speaking order, i.e. {"dialogue": [{"speaker": ..., "text": ...}, ...]}.
    dialogue: List[DialogueItem]
|
|
|
def truncate_text(text, max_tokens=2048):
    """Return *text* limited to at most *max_tokens* tokens of the module tokenizer.

    Text that already fits is returned unchanged (the same object).

    Args:
        text: The string to shorten.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        *text* itself if it fits, otherwise the decoded prefix of its first
        *max_tokens* tokens.
    """
    # disallowed_special=() treats strings like "<|endoftext|>" in user text as
    # ordinary text; tiktoken's default raises ValueError on them.
    tokens = tokenizer.encode(text, disallowed_special=())
    if len(tokens) <= max_tokens:
        return text

    # A token boundary can fall inside a multi-byte UTF-8 character. Plain
    # ``decode`` (errors="replace") would then end the result with U+FFFD, so
    # decode the raw bytes and silently drop the incomplete trailing character.
    truncated = tokenizer.decode_bytes(tokens[:max_tokens])
    return truncated.decode("utf-8", errors="ignore")
|
|
|
def _strip_code_fences(raw: str) -> str:
    """Remove a surrounding Markdown code fence (e.g. ```json ... ```) if present."""
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line, which may carry a language tag like "json".
        _, _, text = text.partition("\n")
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    return text.strip()


def generate_script(
    system_prompt: str,
    input_text: str,
    tone: str,
    *,
    model: str = "llama2-70b-4096",
    max_tokens: int = 2048,
    temperature: float = 0.7,
) -> "Dialogue":
    """Ask a Groq chat model for a Host/Guest dialogue and validate the reply.

    *input_text* is first cut to 2048 tokens with ``truncate_text``. It is then
    combined with *system_prompt* and *tone* into a single system message.
    The model's reply must be JSON matching ``Dialogue``. A reply wrapped in a
    Markdown code fence is accepted.

    Args:
        system_prompt: Instructions for the model, placed first in the prompt.
        input_text: Source material to turn into a dialogue.
        tone: Desired tone, inserted after a ``TONE:`` label.
        model: Groq model id (defaults to the previously hard-coded one).
        max_tokens: Upper bound on completion tokens.
        temperature: Sampling temperature.

    Returns:
        The validated ``Dialogue``.

    Raises:
        ValueError: If the reply is empty or is not valid ``Dialogue`` JSON.
    """
    input_text = truncate_text(input_text)
    prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"

    # NOTE(review): Groq has deprecated "llama2-70b-4096"; confirm it is still
    # served, or pass a current id via `model`. The id suggests a 4096-token
    # window. Up to 2048 input tokens plus the system prompt plus
    # max_tokens=2048 may exceed it, so consider lowering max_tokens.
    response = groq_client.chat.completions.create(
        messages=[{"role": "system", "content": prompt}],
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
    )

    content = response.choices[0].message.content
    if not content:
        raise ValueError("Failed to parse dialogue JSON: model returned no content")

    try:
        return Dialogue.model_validate_json(_strip_code_fences(content))
    except ValidationError as e:
        # Chain the original error so the pydantic details stay in the traceback.
        raise ValueError(f"Failed to parse dialogue JSON: {e}") from e
|
|
|
def generate_audio(text: str, speaker: str) -> str:
    """Synthesize *text* to an MP3 file with gTTS and return its filename.

    The accent is selected through gTTS's ``tld`` option: ``"com"`` for the
    ``"Host"`` speaker and ``"co.uk"`` for every other speaker.

    Args:
        text: The line to speak; must contain non-whitespace characters.
        speaker: Speaker label; its lower-cased form names the output file.

    Returns:
        ``"<speaker>_audio.mp3"``, written relative to the current working
        directory.

    Raises:
        ValueError: If *text* is empty, ``None`` or whitespace-only.

    Note:
        The filename depends only on *speaker*, so a later call for the same
        speaker overwrites the earlier file. Consume each file before
        generating the next line for that speaker.
    """
    # Fail early and clearly. gTTS rejects empty text with an ``assert``,
    # which is stripped when Python runs with -O.
    if not text or not text.strip():
        raise ValueError("Cannot generate audio for empty text")

    tld = "com" if speaker == "Host" else "co.uk"
    tts = gTTS(text, lang="en", tld=tld)
    filename = f"{speaker.lower()}_audio.mp3"
    tts.save(filename)
    return filename