File size: 1,987 Bytes
cfe276e
 
e7fd9f7
cbf8a35
 
 
 
 
626684c
eb10866
 
cbf8a35
 
 
 
 
 
 
 
3c6df0d
cbf8a35
 
 
3c6df0d
 
 
cbf8a35
 
 
 
 
 
 
 
 
 
 
 
 
3c6df0d
cbf8a35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os

# This must run before any `TTS` import below. The variable is presumably read
# by the Coqui TTS package to accept its model terms of service without an
# interactive prompt. NOTE(review): confirm against the TTS documentation.
os.environ["COQUI_TOS_AGREED"] = "1"
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.generic_utils import get_user_data_dir
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import torch
import time
import torchaudio
import io
import base64


class EndpointHandler:
    """Callable handler that turns request text into base64-encoded WAV speech.

    An XTTS model is loaded once from fixed paths under ``/repository/model``.
    Every request is voiced with the speaker captured from a fixed reference
    clip. NOTE(review): the class name and ``__init__(path="")`` signature
    suggest a Hugging Face Inference Endpoints custom handler; confirm.
    """

    # Reference recording whose voice is used for every request.
    _SPEAKER_WAV = "/repository/attenborough.mp3"
    # Sample rate written into the WAV header. NOTE(review): assumed to match
    # the XTTS model's output rate; confirm against the model config.
    _SAMPLE_RATE = 24000

    def __init__(self, path=""):
        """Load the XTTS checkpoint and move it to GPU if available, else CPU.

        Args:
            path: Accepted for interface compatibility. It is currently unused,
                because every model file is read from hard-coded ``/repository`` paths.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        config = XttsConfig()
        config.load_json("/repository/model/config.json")
        model = Xtts.init_from_config(config)
        model.load_checkpoint(
            config,
            checkpoint_path="/repository/model/model.pth",
            vocab_path="/repository/model/vocab.json",
            speaker_file_path="/repository/model/speakers_xtts.pth",
            eval=True,
            # DeepSpeed is only requested when running on a GPU.
            use_deepspeed=device == "cuda",
        )
        model.to(device)

        self.model = model
        # Conditioning latents for the fixed reference speaker. They are filled
        # on the first request and reused afterwards, because the reference clip
        # and its extraction parameters never change between requests.
        self._conditioning = None

    def _get_conditioning(self):
        """Return the cached ``(gpt_cond_latent, speaker_embedding)`` pair for the reference clip."""
        # getattr keeps this working for instances created without __init__.
        cached = getattr(self, "_conditioning", None)
        if cached is None:
            cached = self.model.get_conditioning_latents(
                audio_path=self._SPEAKER_WAV,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
            self._conditioning = cached
        return cached

    def __call__(self, model_input):
        """Synthesize speech for ``model_input["text"]``.

        Args:
            model_input: Mapping that must contain a ``"text"`` entry with the
                text to speak.

        Returns:
            A dict whose ``"data"`` entry holds the WAV bytes as a base64 ASCII
            string and whose ``"format"`` entry is ``"wav"``.

        Raises:
            KeyError: If ``model_input`` has no ``"text"`` entry.
        """
        gpt_cond_latent, speaker_embedding = self._get_conditioning()

        print("Generating audio")
        t0 = time.time()
        out = self.model.inference(
            text=model_input["text"],
            speaker_embedding=speaker_embedding,
            gpt_cond_latent=gpt_cond_latent,
            temperature=0.75,
            repetition_penalty=2.5,
            language="en",
            enable_text_splitting=True,
        )
        # Measure generation only, and compute the value before it is logged.
        # The original read this variable before assigning it.
        inference_time = time.time() - t0
        print(f"I: Time to generate audio: {inference_time} seconds")

        audio_file = io.BytesIO()
        # torchaudio cannot infer the container from a file-like object, so the
        # format must be given explicitly. The added leading dim is the channel.
        torchaudio.save(
            audio_file,
            torch.as_tensor(out["wav"]).unsqueeze(0),
            self._SAMPLE_RATE,
            format="wav",
        )
        audio_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
        return {"data": audio_str, "format": "wav"}