# VITA-Audio/tools/inference_sts.py
import json
import logging
import os
import random
import re
import sys
import time
import uuid
from threading import Thread
from typing import Optional
import torch
import tqdm
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.generation import GenerationConfig
import torchaudio
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
torch.manual_seed(1234)
device_map = "cuda:0"
audio_tokenizer_rank = 0
torch_dtype = torch.bfloat16
# model_name_or_path = sys.argv[1]
# audio_tokenizer_path = sys.argv[2]
# flow_path = sys.argv[3]
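# Manual configuration toggles: exactly one of the blocks below should be active
# ("if True:" vs. "if False:"); the tokenizer, flow, and model paths are local
# placeholders and need to be adjusted for your environment.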
if True:
# if False:
# sensevoice glm4voice tokenizer
sys.path.append("third_party/GLM-4-Voice/")
sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")
audio_tokenizer_path = "/data/models/THUDM/glm-4-voice-tokenizer"
flow_path = "/data/models/THUDM/glm-4-voice-decoder"
audio_tokenizer_type = "sensevoice_glm4voice"
model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla/"
# if True:
if False:
# glm4voice tokenizer
sys.path.append("third_party/GLM-4-Voice/")
sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")
audio_tokenizer_path = "/data/models/THUDM/glm-4-voice-tokenizer"
flow_path = "/data/models/THUDM/glm-4-voice-decoder"
audio_tokenizer_type = "glm4voice"
# model_name_or_path = "VITA-MLLM/VITA-Audio-Balance"
model_name_or_path = "VITA-MLLM/VITA-Audio-Boost"
output_dir = "/data/output/LM/inference/"
os.makedirs(output_dir, exist_ok=True)
class TextAudioIteratorStreamer(TextIteratorStreamer):
def __init__(
self,
tokenizer: "AutoTokenizer",
skip_prompt: bool = False,
timeout: Optional[float] = None,
**decode_kwargs,
):
super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
# self.audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
self.audio_offset = tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
self.num_decode_tokens = 0
def put(self, value):
"""
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
"""
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
self.num_decode_tokens += len(value)
        # Add the new token to the cache and decode the entire thing.
self.token_cache.extend(value.tolist())
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
# After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.token_cache = []
self.print_len = 0
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len :]
self.print_len += len(printable_text)
        # Audio tokens (ids at or above the <|begin_of_audio|> offset) are flushed
        # immediately so consumers can detect complete audio segments token by token.
        elif self.token_cache[-1] >= self.audio_offset:
printable_text = text[self.print_len :]
self.print_len += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = text[self.print_len : text.rfind(" ") + 1]
self.print_len += len(printable_text)
self.on_finalized_text(printable_text)
while self.text_queue.qsize() > 10:
time.sleep(0.01)
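# Typical consumption pattern (used by run_infer_stream below): generation runs in a
# background thread while the caller iterates the streamer; the qsize() check above
# throttles the producer so the queue does not grow unboundedly. A minimal sketch,
# assuming a loaded `model`, `tokenizer`, and tokenized `input_ids`:
#
#     streamer = TextAudioIteratorStreamer(tokenizer, skip_prompt=True)
#     Thread(target=model.generate, kwargs=dict(input_ids=input_ids, streamer=streamer)).start()
#     for new_text in streamer:
#         ...  # text chunks and <|audio_k|> tokens arrive incrementally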
class BenchmarkIteratorStreamer(TextIteratorStreamer):
def __init__(
self,
tokenizer: "AutoTokenizer",
skip_prompt: bool = False,
timeout: Optional[float] = None,
**decode_kwargs,
):
super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
self.num_decode_tokens = 0
def put(self, value):
"""
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
"""
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
self.num_decode_tokens += len(value)
printable_text = " ".join([str(x) for x in value.tolist()]) + " "
self.on_finalized_text(printable_text)
def find_audio_segments_regex(text):
"""
Find all substrings between <|begin_of_audio|> and <|end_of_audio|> using regex.
Args:
text (str): The input string to search through
Returns:
list: A list of all found audio segments (substrings between the delimiters)
"""
pattern = re.compile(r"<\|begin_of_audio\|>(.*?)<\|end_of_audio\|>", re.DOTALL)
segments = pattern.findall(text)
return [segment.strip() for segment in segments]
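# Illustrative example (hypothetical input):
#   find_audio_segments_regex("Hi <|begin_of_audio|><|audio_1|><|audio_2|><|end_of_audio|> bye")
#   returns ["<|audio_1|><|audio_2|>"]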
def extract_token_ids_as_int(text):
pattern = re.compile(r"<\|audio_(\d+)\|>")
token_ids = pattern.findall(text)
return [int(id) for id in token_ids]
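# Illustrative example (hypothetical input):
#   extract_token_ids_as_int("<|audio_12|><|audio_7|>") returns [12, 7]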
def custom_init_weights(module):
if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
torch.nn.init.constant_(module.bias, 0)
elif isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.BatchNorm1d):
torch.nn.init.constant_(module.weight, 1)
torch.nn.init.constant_(module.bias, 0)
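# Note: custom_init_weights is applied in benchmark_generate / benchmark_generate_stream to
# re-randomize the weights before timing, presumably so generation does not stop early at an
# EOS token; those benchmarks only measure decoding speed, so output quality is irrelevant.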
class S2SInference:
def __init__(
self, model_name_or_path, audio_tokenizer_path, audio_tokenizer_type, flow_path=None
):
config = AutoConfig.from_pretrained(
model_name_or_path,
trust_remote_code=True,
)
if "qwen2" in config.model_type.lower():
from evaluation.get_chat_template import qwen2_chat_template as chat_template
add_generation_prompt = True
default_system_message = []
if "hunyuan" in config.model_type.lower():
from evaluation.get_chat_template import hunyuan_chat_template as chat_template
add_generation_prompt = False
default_system_message = [
{
"role": "system",
"content": "You are a helpful AI assistant.",
}
]
luke_system_message = [
{
"role": "system",
"content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
},
]
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True,
chat_template=chat_template,
)
# print(f"{tokenizer=}")
print(f"{tokenizer.get_chat_template()=}")
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch_dtype,
attn_implementation="flash_attention_2",
).eval()
# print("model", model)
print(f"{model.config.model_type=}")
print(f"{model.hf_device_map=}")
model.generation_config = GenerationConfig.from_pretrained(
model_name_or_path, trust_remote_code=True
)
model.generation_config.max_new_tokens = 8192
model.generation_config.chat_format = "chatml"
model.generation_config.max_window_size = 8192
model.generation_config.use_cache = True
# model.generation_config.use_cache = False
model.generation_config.do_sample = False
model.generation_config.temperature = 1.0
model.generation_config.top_k = 50
model.generation_config.top_p = 1.0
model.generation_config.num_beams = 1
model.generation_config.pad_token_id = tokenizer.pad_token_id
if model.config.model_type == "hunyuan":
model.generation_config.eos_token_id = tokenizer.eos_id
print(f"{model.generation_config=}")
audio_tokenizer = get_audio_tokenizer(
audio_tokenizer_path,
audio_tokenizer_type,
flow_path=flow_path,
rank=audio_tokenizer_rank,
)
self.model = model
self.tokenizer = tokenizer
self.audio_tokenizer = audio_tokenizer
self.add_generation_prompt = add_generation_prompt
self.default_system_message = default_system_message
self.luke_system_message = luke_system_message
audio_0_id = tokenizer("<|audio_0|>").input_ids[0]
print(f"{audio_0_id=}")
def benchmark_forward(self, mtp_inference_mode):
print("-" * 100)
print("benchmark_forward...")
print(f"{mtp_inference_mode=}")
total_time = 0
past_key_values = None
use_cache = True
self.model.input_ids = None
self.model.inputs_embeds = None
self.model.hidden_states = [None] * (self.model.config.num_nextn_predict_layers + 1)
self.model.position_ids = None
self.model.attention_mask = None
self.model.mtp_idx = -1
self.model.num_prefill_tokens = -1
model_max_length = 1024
if mtp_inference_mode is not None:
ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
self.model._prepare_mtp_for_generation(mtp_inference_mode, model_max_length)
else:
self.model._prepare_mtp_for_generation(
self.model.generation_config.mtp_inference_mode, model_max_length
)
for i in tqdm.tqdm(range(1, model_max_length + 1)):
if use_cache:
input_ids = torch.tensor([i - 1], dtype=torch.long).unsqueeze(0).to("cuda")
position_ids = torch.tensor([i - 1], dtype=torch.long).unsqueeze(0).to("cuda")
else:
input_ids = torch.arange(i, dtype=torch.long).unsqueeze(0).to("cuda")
position_ids = torch.arange(i, dtype=torch.long).unsqueeze(0).to("cuda")
attention_mask = torch.tensor([1] * i, dtype=torch.float).unsqueeze(0).to("cuda")
torch.cuda.synchronize()
start = time.time()
output = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
num_logits_to_keep=1,
)
torch.cuda.synchronize()
end = time.time()
total_time += end - start
# print(f"{i=} {total_time=}")
past_key_values = output.past_key_values
print()
print(f"{total_time=}")
print(f"second/token {total_time/model_max_length=}")
print(f"token/second {model_max_length/total_time=}")
if mtp_inference_mode is not None:
self.model.mtp_inference_mode = ori_mtp_inference_mode
def benchmark_generate(self, mtp_inference_mode):
self.model.apply(custom_init_weights)
print("-" * 100)
print("benchmark_generate...")
print(f"{mtp_inference_mode=}")
total_time = 0
self.model.generation_config.use_cache = True
self.model.generation_config.max_new_tokens = 8192
if mtp_inference_mode is not None:
ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
self.model.generation_config.mtp_inference_mode = mtp_inference_mode
input_ids = torch.tensor([0], dtype=torch.long).unsqueeze(0).to("cuda")
torch.cuda.synchronize()
start = time.time()
output = self.model.generate(
input_ids,
)
# print(f"{output.size()=}")
torch.cuda.synchronize()
end = time.time()
total_time += end - start
print()
print(f"{total_time=}")
print(f"second/token {total_time/output.size(1)=}")
print(f"token/second {output.size(1)/total_time=}")
if mtp_inference_mode is not None:
self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode
def benchmark_generate_stream(self, mtp_inference_mode):
print("-" * 100)
print("benchmark_generate_stream...")
print(f"{mtp_inference_mode=}")
self.model.apply(custom_init_weights)
total_time = 0
self.model.generation_config.use_cache = True
# model_max_length = 8192
model_max_length = 4096
# model_max_length = 2048
# model_max_length = 1024
num_prefill_tokens = 32
self.model.generation_config.max_new_tokens = model_max_length
self.model.generation_config.do_sample = False
if mtp_inference_mode is not None:
ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
self.model.generation_config.mtp_inference_mode = mtp_inference_mode
input_ids = torch.tensor([0] * num_prefill_tokens, dtype=torch.long).unsqueeze(0).to("cuda")
streamer = BenchmarkIteratorStreamer(self.tokenizer, skip_prompt=True)
generation_kwargs = dict(input_ids=input_ids, streamer=streamer)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
token_decode_time = []
torch.cuda.synchronize()
start = time.time()
thread.start()
generated_text = ""
for new_text in tqdm.tqdm(streamer, total=model_max_length):
generated_text += new_text
end = time.time()
token_decode_time.append(end - start)
yield new_text
# print(f"{len(generated_text)}")
torch.cuda.synchronize()
end = time.time()
total_time += end - start
print()
print(f"{token_decode_time[-1]=}")
print(f"{streamer.num_decode_tokens=}")
print(f"second/token {token_decode_time[-1]/streamer.num_decode_tokens=}")
print(f"token/second {streamer.num_decode_tokens/token_decode_time[-1]=}")
# if mtp_inference_mode is None:
# mtp_inference_mode = []
# with open(f'token_decode_time_{str(mtp_inference_mode)}.json', 'w') as f:
# json.dump(token_decode_time, f)
if mtp_inference_mode is not None:
self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode
def run_infer(
self,
audio_path=None,
prompt_audio_path=None,
stream_stride=4,
max_returned_tokens=4096,
sample_rate=16000,
request_id="",
audio_feats=None,
message="",
use_past=False,
mode="luke",
do_sample=False,
mtp_inference_mode=None,
):
AUD_TAG_TOKEN = "<|audio|>"
AUD_CONTEXT_TOKEN = "<|context_of_audio|>"
AUD_START_TOKEN = "<|begin_of_audio|>"
AUD_END_TOKEN = "<|end_of_audio|>"
if prompt_audio_path is not None:
system_message = [
{
"role": "system",
"content": f"Your Voice: <|audio|>\n",
},
]
elif mode == "luke":
system_message = self.luke_system_message
else:
system_message = self.default_system_message
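        # Prompt/user audio is injected in one of two ways, depending on what the audio
        # tokenizer supports for the "user" role: as discrete codec tokens spliced into the
        # chat text (<|begin_of_audio|><|audio_k|>...<|end_of_audio|>), or as contiguous
        # features attached after tokenization via add_audio_input_contiguous.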
if prompt_audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
# discrete codec
audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
system_message[-1]["content"] = system_message[-1]["content"].replace(
"<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
)
if audio_path is not None:
messages = system_message + [
{
"role": "user",
"content": message + "\n<|audio|>",
},
]
else:
messages = system_message + [
{
"role": "user",
"content": message,
},
]
if audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
# discrete codec
audio_tokens = self.audio_tokenizer.encode(audio_path)
audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
messages[-1]["content"] = messages[-1]["content"].replace(
"<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
)
input_ids = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=self.add_generation_prompt,
)
if (audio_path is not None or prompt_audio_path is not None) and self.audio_tokenizer.apply_to_role(
"user", is_contiguous=True
):
# contiguous codec
audio_paths = []
if audio_path is not None:
audio_paths.append(audio_path)
if prompt_audio_path is not None:
audio_paths.append(prompt_audio_path)
input_ids, audios, audio_indices = add_audio_input_contiguous(
input_ids, audio_paths, self.tokenizer, self.audio_tokenizer
)
else:
audios = None
audio_indices = None
input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
print("input", self.tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)
self.model.generation_config.do_sample = do_sample
if mtp_inference_mode is not None:
ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
self.model.generation_config.mtp_inference_mode = mtp_inference_mode
outputs = self.model.generate(
input_ids,
audios=audios,
audio_indices=audio_indices,
)
output = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
print(f"{output=}", flush=True)
audio_offset = self.tokenizer.convert_tokens_to_ids("<|audio_0|>")
audio_tokens = []
for token_id in outputs[0]:
if token_id >= audio_offset:
audio_tokens.append(token_id - audio_offset)
if len(audio_tokens) > 0:
tts_speech = self.audio_tokenizer.decode(
audio_tokens, source_speech_16k=prompt_audio_path
)
else:
tts_speech = None
if mtp_inference_mode is not None:
self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode
return output, tts_speech
def run_infer_stream(
self,
audio_path=None,
prompt_audio_path=None,
stream_stride=4,
max_returned_tokens=4096,
sample_rate=16000,
request_id="",
audio_feats=None,
message="",
use_past=False,
mode="luke",
do_sample=False,
mtp_inference_mode=None,
):
if prompt_audio_path is not None:
system_message = [
{
"role": "system",
"content": f"Your Voice: <|audio|>\n",
},
]
elif mode == "luke":
system_message = self.luke_system_message
else:
system_message = self.default_system_message
if prompt_audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
# discrete codec
audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
system_message[-1]["content"] = system_message[-1]["content"].replace(
"<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
)
if audio_path is not None:
messages = system_message + [
{
"role": "user",
"content": message + "\n<|audio|>",
},
]
else:
messages = system_message + [
{
"role": "user",
"content": message,
},
]
if audio_path is not None and self.audio_tokenizer.apply_to_role("user", is_discrete=True):
# discrete codec
audio_tokens = self.audio_tokenizer.encode(audio_path)
audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
messages[-1]["content"] = messages[-1]["content"].replace(
"<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
)
input_ids = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=self.add_generation_prompt,
)
if (audio_path is not None or prompt_audio_path is not None) and self.audio_tokenizer.apply_to_role(
"user", is_contiguous=True
):
# contiguous codec
audio_paths = []
if audio_path is not None:
audio_paths.append(audio_path)
if prompt_audio_path is not None:
audio_paths.append(prompt_audio_path)
input_ids, audios, audio_indices = add_audio_input_contiguous(
input_ids, audio_paths, self.tokenizer, self.audio_tokenizer
)
else:
audios = None
audio_indices = None
input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
print("input", self.tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)
self.model.generation_config.do_sample = do_sample
if mtp_inference_mode is not None:
ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
self.model.generation_config.mtp_inference_mode = mtp_inference_mode
streamer = TextAudioIteratorStreamer(self.tokenizer, skip_prompt=True)
generation_kwargs = dict(
input_ids=input_ids,
audios=audios,
audio_indices=audio_indices,
streamer=streamer,
)
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
# generated_text = ""
for new_text in streamer:
# generated_text += new_text
yield new_text
# torch.cuda.synchronize()
if mtp_inference_mode is not None:
self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode
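# The mtp_inference_mode lists below correspond, via the paired tags, to the VITA-Audio
# variants (Vanilla / Balance / Boost / Turbo); their exact encoding is defined by the
# model's generation_config / modeling code rather than documented here.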
def benchmark_llm():
for mtp_inference_mode, tag in zip(
[
[8192, 0],
[1, 4, 3, 8, 4, 10],
[1, 10, 4, 10],
[1, 10],
],
[
"Vanilla",
"Balance",
"Boost",
"Turbo",
],
):
print("=" * 100)
print("benchmark_llm")
print(f"{tag}")
s2s_inference.benchmark_forward(mtp_inference_mode)
s2s_inference.benchmark_generate(mtp_inference_mode)
generated_text = ""
for new_text in s2s_inference.benchmark_generate_stream(
mtp_inference_mode=mtp_inference_mode
):
generated_text += new_text
# print(new_text, end="", flush=True)
def benchmark_sts():
audio_paths = [
"asset/介绍一下上海.wav",
"asset/发表一个悲伤的演讲.wav",
"asset/发表一个振奋人心的演讲.wav",
]
for _ in range(10):
print("=" * 100)
print("benchmark_sts")
audio_path = random.choice(audio_paths)
print(f"{audio_path}")
start = time.time()
audio_idx = 0
generated_text = ""
all_tts_speech = []
past_tts_speech_len = 0
for new_text in s2s_inference.run_infer_stream(audio_path=audio_path):
# print(new_text, end="", flush=True)
generated_text += new_text
if new_text == "<|end_of_audio|>":
audio_tokens = extract_token_ids_as_int(generated_text)
tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens, option_steps=1)
tts_speech = tts_speech[past_tts_speech_len:]
past_tts_speech_len += len(tts_speech)
all_tts_speech.append(tts_speech)
end = time.time()
if audio_idx == 0:
print(audio_tokens)
print(f"{audio_idx} audio chunk {end - start}")
wav_path = os.path.join(output_dir, audio_path[:-4] + f"_{audio_idx}.wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
audio_idx += 1
start = time.time()
wav_path = os.path.join(output_dir, audio_path[:-4] + ".wav")
tts_speech = torch.cat(all_tts_speech, dim=0)
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# ==============================================================
# Text
def text_task():
for text in [
"How many helicopters can a human eat in one sitting?",
"你叫什么名字?",
"写一首诗",
"介绍一下上海",
]:
print("=" * 100)
print("text_task")
print(f"{text=}")
output, _ = s2s_inference.run_infer(
message=text,
mode=None,
# do_sample=True,
mtp_inference_mode=[8192, 0],
)
print(f"{output=}", flush=True)
# ==============================================================
# Text stream
def text_stream_task():
for text in [
"你叫什么名字?",
]:
print("=" * 100)
print("text_stream_task")
print(f"{text=}")
generated_text = ""
for new_text in s2s_inference.run_infer_stream(
message=text,
mode=None,
# do_sample=True,
mtp_inference_mode=[8192, 0],
):
generated_text += new_text
print(new_text, end="")
print("")
# ==============================================================
# S2S
def sts_task():
for audio_path in [
"asset/介绍一下上海.wav",
"asset/发表一个悲伤的演讲.wav",
"asset/发表一个振奋人心的演讲.wav",
"asset/piano.mp3",
]:
print("=" * 100)
print("sts_task")
print(f"{audio_path=}")
output, tts_speech = s2s_inference.run_infer(
audio_path=audio_path,
)
wav_path = os.path.join(output_dir, audio_path[:-4] + ".wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# ==============================================================
# S2S stream
def sts_stream_task():
for audio_path in [
"asset/介绍一下上海.wav",
]:
print("=" * 100)
print("sts_stream_task")
print(f"{audio_path=}")
generated_text = ""
for new_text in s2s_inference.run_infer_stream(audio_path=audio_path):
generated_text += new_text
print(new_text, end="")
print("")
audio_decode_time = []
audio_segments = find_audio_segments_regex(generated_text)
for audio_idx, audio_segment in enumerate(audio_segments):
start = time.time()
audio_tokens = extract_token_ids_as_int(audio_segment)
# print(audio_tokens)
tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens)
end = time.time()
audio_decode_time.append(end - start)
wav_path = os.path.join(output_dir, audio_path[:-4] + f"_{audio_idx}.wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# print(f"{audio_decode_time=}")
# ==============================================================
# ASR
def asr_task():
for audio_path in [
"/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000020_5XD21BihDd8_S00395.wav",
"/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00424.wav",
"/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000050_LOLTeK1BNMo_S00045.wav",
"/data/data/fixie-ai/librispeech_asr/test.clean/2830-3980-0034.wav",
"/data/data/fixie-ai/librispeech_asr/test.clean/237-134500-0040.wav",
]:
print("=" * 100)
print("asr_task")
print(f"{audio_path=}")
output, tts_speech = s2s_inference.run_infer(
audio_path=audio_path,
# message="Translate the speech to text.",
message="Convert the speech to text.",
mode=None,
)
print(f"{output=}", flush=True)
# ==============================================================
# TTS
def tts_task():
TTS_texts = [
"我们将为全球城市的可持续发展贡献力量。",
"通天河 灵感大王",
"他本是我莲花池里养大的金鱼,每日浮头听经,修成手段。那一柄九瓣铜锤,乃是一枝未开的菡萏,被他运炼成兵。不知是那一日,海潮泛涨,走到此间。我今早扶栏看花,却不见这厮出拜,掐指巡纹,算着他在此成精,害你师父,故此未及梳妆,运神功,织个竹篮儿擒他。",
"一二三四五六七八九十",
"One Two Tree Four Five Six Seven Eight Night Ten",
"1 2 3 4 5 6 7 8 9 10",
"12345678910",
"两个黄鹂鸣翠柳,一行白鹭上青天。窗含西岭千秋雪,门泊东吴万里船。",
"坡上立着一只鹅,坡下就是一条河。宽宽的河,肥肥的鹅,鹅要过河,河要渡鹅不知是鹅过河,还是河渡鹅?",
"扁担长,板凳宽,扁担没有板凳宽,板凳没有扁担长。扁担绑在板凳上,板凳不让扁担绑在板凳上。",
"化肥会挥发,黑化肥发灰,灰化肥发黑。黑化肥发灰会挥发;灰化肥挥发会发黑。黑化肥挥发发灰会花飞;灰化肥挥发发黑会飞花,黑灰化肥会挥发发灰黑讳为花飞;灰黑化肥会挥发发黑灰为讳飞花。",
"圆桌儿、方桌儿没有腿儿,墨水瓶儿里没有水儿,花瓶里有花儿没有叶儿,练习本儿上写字儿没有准儿,甘蔗好吃净是节儿。西瓜挺大没有味儿,坛儿里的小米儿长了虫儿,鸡毛掸子成了棍儿,水缸沿儿上系围裙儿,耗子打更猫打盹儿,新买的小褂儿没钉扣儿,奶奶想说没有劲儿。",
"起床歌:小宝宝,起得早,睁开眼,眯眯笑,咿呀呀,学说话,伸伸手,要人抱。穿衣歌小胳膊,穿袖子,穿上衣,扣扣子,小脚丫,穿裤子,穿上袜子穿鞋子。小镜子-小镜子,圆又圆,看宝宝,露笑脸。闭上眼,做个梦,变月亮,挂上天。小铃铛叮铃铃,叮铃铃,一会远,一会近。小宝宝,耳朵灵,听铃声,找到铃。学画画小宝宝,学画画,大蜡笔,手中拿,画小鸭,叫嘎嘎,画小马,骑回家。大鞋子大鞋子,像只船,爸爸穿,我也穿,一二一,向前走,走呀走,翻了船。逛公园逛公园,宝宝笑,东看看,西瞧瞧,花儿香,鸟儿叫,小草绿,小树摇。看画报小娃娃,看画报,睁大眼,仔细瞧,布娃娃,哈哈笑,伸伸手,要你抱。搭积木大积木,红黄兰,小宝宝,最爱玩,搭火车,钻山洞,盖高楼,连着天。小汽车小汽车,嘀嘀嘀,开过来,开过去,小宝宝,当司机,送妈妈,上班去。藏猫猫儿歌:躲猫猫,躲猫猫, 猫猫、猫猫在哪里?喵……猫咪在这里。",
]
for text in TTS_texts:
print("=" * 100)
print("tts_task")
print(f"{text=}")
output, tts_speech = s2s_inference.run_infer(
message="Convert the text to speech.\n" + text,
mode=None,
do_sample=True,
)
wav_path = os.path.join(output_dir, text[:16] + ".wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# ==============================================================
# Clone TTS
for text in TTS_texts:
for prompt_audio_path in [
"asset/2631296891109983590.wav",
"asset/379838640-d5ff0815-74f8-4738-b0f1-477cfc8dcc2d.wav",
"asset/4202818730519913143.wav",
]:
print("=" * 100)
print("tts_task")
print(f"{text=} {prompt_audio_path=}")
output, tts_speech = s2s_inference.run_infer(
prompt_audio_path=prompt_audio_path,
# message="Translate the text to speech.\n" + text,
message="Convert the text to speech.\n" + text,
mode=None,
do_sample=True,
)
wav_path = os.path.join(output_dir, prompt_audio_path[:16] + "_" + text[:16] + ".wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
# ==============================================================
# TTS stream
def tts_stream_task():
TTS_texts = [
"他本是我莲花池里养大的金鱼,每日浮头听经,修成手段。那一柄九瓣铜锤,乃是一枝未开的菡萏,被他运炼成兵。不知是那一日,海潮泛涨,走到此间。我今早扶栏看花,却不见这厮出拜,掐指巡纹,算着他在此成精,害你师父,故此未及梳妆,运神功,织个竹篮儿擒他。",
]
for text in TTS_texts:
print("=" * 100)
print("tts_stream_task")
print(f"{text=}")
generated_text = ""
for new_text in s2s_inference.run_infer_stream(
message="Convert the text to speech.\n" + text,
mode=None,
do_sample=True,
):
generated_text += new_text
print(new_text, end="")
print("")
audio_segments = find_audio_segments_regex(generated_text)
for audio_idx, audio_segment in enumerate(audio_segments):
audio_tokens = extract_token_ids_as_int(audio_segment)
# print(audio_tokens)
tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens)
wav_path = os.path.join(output_dir, text[:16] + f"_{audio_idx}.wav")
os.makedirs(os.path.dirname(wav_path), exist_ok=True)
torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
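# Entry point: build the speech-to-speech pipeline once, then run each task and benchmark
# in sequence.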
s2s_inference = S2SInference(
model_name_or_path, audio_tokenizer_path, audio_tokenizer_type, flow_path=flow_path
)
text_task()
text_stream_task()
sts_task()
sts_stream_task()
asr_task()
tts_task()
tts_stream_task()
benchmark_sts()
benchmark_llm()