File size: 2,422 Bytes
dc5bc59 a168b4f 910d316 dc5bc59 a168b4f dc5bc59 a168b4f 910d316 a168b4f dc5bc59 a168b4f 910d316 5dfce18 910d316 5dfce18 910d316 5dfce18 910d316 5dfce18 910d316 5dfce18 910d316 a168b4f 910d316 a168b4f 5dfce18 910d316 5dfce18 910d316 5dfce18 910d316 5dfce18 910d316 a168b4f 910d316 5dfce18 a168b4f 910d316 a168b4f 910d316 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import io
from starlette.responses import StreamingResponse
# Initialize the FastAPI app; the route decorators below register against this instance.
app = FastAPI()
# Define a Pydantic model for the items
class Item(BaseModel):
    """Request payload element for the POST endpoint: one text snippet plus metadata."""
    text: str     # the text to synthesize into speech
    name: str     # item identifier (semantics not shown here — presumably a label; echoed back unchanged)
    section: str  # grouping label for the item; echoed back unchanged
# Initialize ParlerTTS at import time so the first request does not pay the load cost.
# Prefer the first GPU when available; weights are downloaded/cached by the hub on first run.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
# A simple GET endpoint, usable as a liveness/health check.
@app.get("/")
def greet_json():
    """Return a static greeting so callers can confirm the service is up."""
    greeting = {"Hello": "World!"}
    return greeting
# Function to generate audio from text using ParlerTTS
def generate_audio(text, description="Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."):
    """Synthesize speech for *text* with the module-level ParlerTTS model.

    Parameters
    ----------
    text : str
        The text to speak.
    description : str
        Natural-language voice/style prompt; ParlerTTS conditions the voice
        on this description while *text* supplies the words to speak.

    Returns
    -------
    tuple
        ``(audio_arr, sampling_rate)`` — a squeezed numpy waveform array and
        the model's configured sampling rate.
    """
    # The description goes through the main tokenizer input; the text to speak
    # is passed separately as the prompt.
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    # Inference only — disable autograd bookkeeping to cut memory and time.
    with torch.no_grad():
        generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr, model.config.sampling_rate
# A POST endpoint to receive and parse an array of JSON objects
@app.post("/")
def create_items(items: List[Item]):
    """Synthesize audio for each posted item and echo the items back.

    Declared as a plain ``def`` (not ``async def``) so FastAPI executes it in
    a worker thread: ``generate_audio`` is blocking CPU/GPU work and would
    stall the event loop inside an async handler.

    NOTE(review): the generated audio is currently discarded — only item
    metadata plus a ``processed`` flag is returned. A StreamingResponse cannot
    be embedded in a JSON body; returning audio would need a dedicated
    endpoint or base64 encoding of the WAV bytes.

    Parameters
    ----------
    items : List[Item]
        The JSON array of items to process.

    Returns
    -------
    dict
        ``{"processed_items": [...]}`` with one entry per input item.
    """
    processed_items = []
    for item in items:
        # Run TTS for this item; the waveform is not yet used in the response.
        audio_arr, sample_rate = generate_audio(item.text)
        processed_items.append({
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
        })
    return {"processed_items": processed_items}
# Local development entry point: `python thisfile.py` serves on 127.0.0.1:8000.
if __name__ == "__main__":
    # Imported here so the module can also be served by an external runner
    # (e.g. `uvicorn module:app`) without this import executing at load time.
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)
|