from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import io
from starlette.responses import StreamingResponse

# Initialize the FastAPI app
app = FastAPI()

# Define a Pydantic model for the items
class Item(BaseModel):
    text: str
    name: str
    section: str

# Initialize ParlerTTS
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

# A simple GET endpoint
@app.get("/")
def greet_json():
    return {"Hello": "World!"}

# Function to generate audio from text using ParlerTTS
def generate_audio(text, description="Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."):
    # Tokenize the voice description and the text prompt
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    # Generate the waveform and move it to the CPU as a NumPy array
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr, model.config.sampling_rate
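
# StreamingResponse, soundfile, and io are imported above but only exercised in
# the commented-out block further down. The endpoint below is a minimal sketch
# (not part of the original app; the /speak path is a hypothetical name) of how
# those pieces could be wired together to stream the generated WAV back directly.
@app.post("/speak")
async def speak(item: Item):
    audio_arr, sample_rate = generate_audio(item.text)
    # Write the waveform into an in-memory WAV buffer
    audio_bytes = io.BytesIO()
    sf.write(audio_bytes, audio_arr, sample_rate, format="WAV")
    audio_bytes.seek(0)  # Rewind so the response reads from the start
    return StreamingResponse(audio_bytes, media_type="audio/wav")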

# A POST endpoint to receive and parse an array of JSON objects
@app.post("/")
async def create_items(items: List[Item]):
    processed_items = []
    for item in items:
        print(f"Processing item: {item.text}")
        # Generate audio
        print("before")
        audio_arr, sample_rate = generate_audio(item.text)
        print("after")

        # # Create in-memory bytes buffer for audio
        # audio_bytes = io.BytesIO()
        # sf.write(audio_bytes, audio_arr, sample_rate, format="WAV")
        # audio_bytes.seek(0)  # Reset buffer position

        processed_item = {
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
            # "audio": StreamingResponse(audio_bytes, media_type="audio/wav")
        }
        processed_items.append(processed_item)

    return {"processed_items": processed_items}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)
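
# Example request against the batch endpoint above, assuming the server is run
# locally via __main__ (host 127.0.0.1, port 8000). Illustrative only; adjust
# host/port for your deployment.
#
#   curl -X POST http://127.0.0.1:8000/ \
#        -H "Content-Type: application/json" \
#        -d '[{"text": "Hello there.", "name": "intro", "section": "1"}]'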