from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import io
import base64

app = FastAPI()


class Item(BaseModel):
    text: str
    name: str
    section: str


# Load the Parler-TTS model and tokenizer once at startup so every request reuses them.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")


@app.get("/")
def greet_json():
    return {"Hello": "World!"}


def generate_audio(text, description="Neutral voice"):
    # The description conditions the speaking style; the text is the prompt that gets spoken.
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr, model.config.sampling_rate
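

# Example (sketch, values are illustrative): generate_audio("Good morning!",
# description="A calm female speaker with clear audio quality.") returns a NumPy
# waveform plus the model's sampling rate, which sf.write can save straight to disk.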


@app.post("/")
async def create_items(items: List[Item]):
    processed_items = []
    for item in items:
        audio_arr, sample_rate = generate_audio(item.text)

        # Write the generated audio to an in-memory WAV file. A StreamingResponse
        # cannot be embedded inside a JSON payload, so the audio is returned
        # base64-encoded instead.
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio_arr, sample_rate, format="WAV")
        audio_bytes.seek(0)

        processed_item = {
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
            "audio": base64.b64encode(audio_bytes.read()).decode("utf-8"),
        }
        processed_items.append(processed_item)

    return {"processed_items": processed_items}
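

# Example request (hedged sketch; client-side code, not executed by this app): a client
# could POST a JSON list of items and decode each returned "audio" field from base64.
#
#   import base64, requests
#   resp = requests.post("http://127.0.0.1:8000/", json=[
#       {"text": "Hello there", "name": "intro", "section": "1"},
#   ])
#   for entry in resp.json()["processed_items"]:
#       with open(f"{entry['name']}.wav", "wb") as f:
#           f.write(base64.b64decode(entry["audio"]))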


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)