from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import io
from starlette.responses import StreamingResponse

# Initialize the FastAPI app
app = FastAPI()


class Item(BaseModel):
    """One text snippet to synthesize, plus identifying metadata echoed back."""

    text: str
    name: str
    section: str


# Initialize ParlerTTS once at import time; inference runs on the first GPU
# when available, otherwise on CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    "parler-tts/parler-tts-mini-v1"
).to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")


@app.get("/")
def greet_json():
    """Trivial health-check endpoint."""
    return {"Hello": "World!"}


def generate_audio(text, description="Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."):
    """Synthesize speech for *text* with ParlerTTS.

    Args:
        text: The text to convert to speech.
        description: Natural-language description of the desired voice,
            used by ParlerTTS to condition the generation.

    Returns:
        Tuple of (audio samples as a squeezed numpy array,
        sampling rate taken from the model config).
    """
    # ParlerTTS conditions on two token streams: the voice description and
    # the prompt text itself.
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr, model.config.sampling_rate


@app.post("/")
async def create_items(items: List[Item]):
    """Synthesize audio for each item and echo back per-item metadata.

    The generated audio is currently discarded: a StreamingResponse (or any
    raw binary) cannot be embedded inside a JSON response body, which is why
    the earlier attempt to include an "audio" field was abandoned.
    NOTE(review): to actually deliver the audio, base64-encode the WAV bytes
    into the JSON payload, or serve each clip from a separate streaming
    endpoint.
    """
    processed_items = []
    for item in items:
        # Audio generation is the expensive step; its result is dropped for
        # now (see docstring) but the call is kept so behavior matches the
        # original endpoint.
        audio_arr, sample_rate = generate_audio(item.text)
        processed_items.append(
            {
                "text": item.text,
                "name": item.name,
                "section": item.section,
                "processed": True,
            }
        )
    return {"processed_items": processed_items}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)