File size: 897 Bytes
01e655b 02e90e4 01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import numpy as np
import torch
from modules.utils.SeedContext import SeedContext
from modules import models, config
@torch.inference_mode()
def refine_text(
    text: str,
    prompt="[oral_2][laugh_0][break_6]",
    seed=-1,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
) -> str:
    """Run *text* through the ChatTTS refiner and return the refined string.

    The refiner inserts/normalizes control tokens (e.g. oral/laugh/break
    markers from *prompt*) according to the given sampling parameters.
    Generation is seeded via SeedContext so a fixed *seed* yields a
    reproducible result. Input text normalization is deliberately disabled
    here (handled elsewhere in the pipeline — confirm with callers).
    """
    chat_tts = models.load_chat_tts()
    # Collected once so the refiner call below stays readable.
    sampling_params = {
        "prompt": prompt,
        "top_K": top_K,
        "top_P": top_P,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_token": max_new_token,
        # Honor the global "no progress bars" runtime switch.
        "disable_tqdm": config.runtime_env_vars.off_tqdm,
    }
    with SeedContext(seed):
        return chat_tts.refiner_prompt(
            text,
            sampling_params,
            do_text_normalization=False,
        )
|