# parler/app.py — Parler-TTS FastAPI service
# Author: Carsten Høyer — commit 910d316 ("add parler"), 2.14 kB
import base64
import io
from typing import List

import soundfile as sf
import torch
from fastapi import FastAPI
from parler_tts import ParlerTTSForConditionalGeneration
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoTokenizer
# Initialize the FastAPI app (single module-level instance served by uvicorn).
app = FastAPI()
# Request schema: one entry of the JSON array accepted by the POST endpoint.
class Item(BaseModel):
    """A single text snippet to synthesize, plus caller-supplied metadata."""

    text: str  # the text to convert to speech
    name: str  # identifier echoed back in the response
    section: str  # grouping label echoed back in the response
# Initialize ParlerTTS
# Prefer the first CUDA device when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Load the pretrained mini model and its tokenizer once at import time so
# every request reuses them (downloads the weights on first run).
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
# Simple health-check endpoint on the root path.
@app.get("/")
def greet_json():
    """Return a static greeting so deployments can verify the app is up."""
    greeting = {"Hello": "World!"}
    return greeting
# Text-to-speech helper built on the module-level ParlerTTS model.
def generate_audio(text, description="Neutral voice"):
    """Synthesize *text* into a waveform with the module-level Parler-TTS model.

    Args:
        text: The prompt to speak aloud.
        description: Natural-language description steering the voice style.

    Returns:
        A ``(samples, sampling_rate)`` tuple, where ``samples`` is a NumPy
        array of audio samples and ``sampling_rate`` comes from the model
        config.
    """
    # The model takes the voice description and the spoken prompt as two
    # separate token sequences.
    desc_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    waveform = model.generate(input_ids=desc_ids, prompt_input_ids=prompt_ids)
    samples = waveform.cpu().numpy().squeeze()
    return samples, model.config.sampling_rate
# A POST endpoint to receive and parse an array of JSON objects
@app.post("/")
async def create_items(items: List[Item]):
    """Synthesize audio for each item and return it base64-encoded.

    Bug fix: the original implementation put a ``StreamingResponse`` object
    inside the returned dict; FastAPI cannot serialize a response object as
    JSON, so the endpoint could not return the audio. The WAV bytes are now
    base64-encoded so the whole payload is valid JSON.

    Args:
        items: Parsed request body — a JSON array of ``Item`` objects.

    Returns:
        ``{"processed_items": [...]}`` where each entry echoes the item's
        fields and carries its WAV audio as a base64 string under "audio".
    """
    processed_items = []
    for item in items:
        # Synthesize the waveform for this item's text.
        audio_arr, sample_rate = generate_audio(item.text)
        # Serialize the waveform to an in-memory WAV file.
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio_arr, sample_rate, format="WAV")
        audio_bytes.seek(0)  # Rewind before reading the buffer back.
        processed_items.append({
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
            # base64 makes the binary audio representable inside JSON.
            "audio": base64.b64encode(audio_bytes.read()).decode("ascii"),
        })
    return {"processed_items": processed_items}
if __name__ == "__main__":
    # Run a local development server when the file is executed directly
    # (production deployments launch the app via an external uvicorn process).
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)