simdi commited on
Commit
cbf8a35
1 Parent(s): 41a54f3

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +61 -0
handler.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from TTS.api import TTS
3
+ from TTS.utils.manage import ModelManager
4
+ from TTS.utils.generic_utils import get_user_data_dir
5
+ import torch
6
+ import os
7
+ from TTS.tts.configs.xtts_config import XttsConfig
8
+ import torchaudio
9
+ from TTS.tts.models.xtts import Xtts
10
+ import io
11
+ import base64
12
+
13
+
14
+ class EndpointHandler:
15
+ def __init__(self, path=""):
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ config = XttsConfig()
18
+ config.load_json("./model/config.json")
19
+ model = Xtts.init_from_config(config)
20
+ model.load_checkpoint(
21
+ config,
22
+ checkpoint_path="./model/model.pth",
23
+ vocab_path="./model/vocab.json",
24
+ speaker_file_path="./model/speakers_xtts.pth",
25
+ eval=True,
26
+ use_deepspeed=device == "cuda",
27
+ )
28
+ model.to(device)
29
+
30
+ self.model = model
31
+
32
+ def __call__(self, model_input):
33
+
34
+ (
35
+ gpt_cond_latent,
36
+ speaker_embedding,
37
+ ) = self.model.get_conditioning_latents(
38
+ audio_path="attenborough.mp3",
39
+ gpt_cond_len=30,
40
+ gpt_cond_chunk_len=4,
41
+ max_ref_length=60,
42
+ )
43
+
44
+ print("Generating audio")
45
+ t0 = time.time()
46
+ out = self.model.inference(
47
+ text=model_input["text"],
48
+ speaker_embedding=speaker_embedding,
49
+ gpt_cond_latent=gpt_cond_latent,
50
+ temperature=0.75,
51
+ repetition_penalty=2.5,
52
+ language="en",
53
+ enable_text_splitting=True,
54
+ )
55
+ print(f"I: Time to generate audio: {inference_time} seconds")
56
+ audio_file = io.BytesIO()
57
+ torchaudio.save(audio_file, torch.tensor(out["wav"]).unsqueeze(0), 24000)
58
+ inference_time = time.time() - t0
59
+ audio_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
60
+ return {"data": audio_str, "format": "wav"}
61
+