import json
import logging
import os
import random
import re
import sys
import time
import uuid
from threading import Thread
from typing import Optional

import torch
import tqdm
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.generation import GenerationConfig

import torchaudio
from vita_audio.data.processor.audio_processor import add_audio_input_contiguous
from vita_audio.tokenizer import get_audio_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

torch.manual_seed(1234)

device_map = "cuda:0"
audio_tokenizer_rank = 0
torch_dtype = torch.bfloat16

# model_name_or_path = sys.argv[1]
# audio_tokenizer_path = sys.argv[2]
# flow_path = sys.argv[3]

if True:
    # if False:
    # sensevoice glm4voice tokenizer
    sys.path.append("third_party/GLM-4-Voice/")
    sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
    sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

    audio_tokenizer_path = "/data/models/THUDM/glm-4-voice-tokenizer"
    flow_path = "/data/models/THUDM/glm-4-voice-decoder"
    audio_tokenizer_type = "sensevoice_glm4voice"

    model_name_or_path = "VITA-MLLM/VITA-Audio-Plus-Vanilla/"

# if True:
if False:
    # glm4voice tokenizer
    sys.path.append("third_party/GLM-4-Voice/")
    sys.path.append("third_party/GLM-4-Voice/cosyvoice/")
    sys.path.append("third_party/GLM-4-Voice/third_party/Matcha-TTS/")

    audio_tokenizer_path = "/data/models/THUDM/glm-4-voice-tokenizer"
    flow_path = "/data/models/THUDM/glm-4-voice-decoder"
    audio_tokenizer_type = "glm4voice"

    # model_name_or_path = "VITA-MLLM/VITA-Audio-Balance"
    model_name_or_path = "VITA-MLLM/VITA-Audio-Boost"

output_dir = "/data/output/LM/inference/"
os.makedirs(output_dir, exist_ok=True)
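
# NOTE: the tokenizer/model/output paths above are environment-specific examples; point them at
# your local checkpoints (or restore the commented-out sys.argv handling above) before running.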


class TextAudioIteratorStreamer(TextIteratorStreamer):
    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        # self.audio_offset = tokenizer.convert_tokens_to_ids("<|audio_0|>")
        self.audio_offset = tokenizer.convert_tokens_to_ids("<|begin_of_audio|>")
        self.num_decode_tokens = 0

    def put(self, value):
        """
        Receives tokens, decodes them, and pushes the text to the queue as soon as it forms entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        self.num_decode_tokens += len(value)

        # Add the new tokens to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # If the last token is an audio token (id >= audio_offset), flush immediately so audio
        # chunks reach the consumer without waiting for a word boundary.
        elif self.token_cache[-1] >= self.audio_offset:
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, print until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

        # Simple backpressure: pause while the consumer is more than 10 chunks behind.
        while self.text_queue.qsize() > 10:
            time.sleep(0.01)
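
# Illustrative usage (this mirrors run_infer_stream below): generation runs in a background
# thread and the streamer is consumed as an iterator, e.g.
#
#     streamer = TextAudioIteratorStreamer(tokenizer, skip_prompt=True)
#     thread = Thread(target=model.generate, kwargs=dict(input_ids=input_ids, streamer=streamer))
#     thread.start()
#     for new_text in streamer:
#         ...  # handle interleaved text and <|audio_*|> tokens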


class BenchmarkIteratorStreamer(TextIteratorStreamer):
    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs,
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.num_decode_tokens = 0

    def put(self, value):
        """
        Receives tokens and pushes their raw ids (as text) to the queue without decoding,
        for throughput benchmarking.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        self.num_decode_tokens += len(value)

        printable_text = " ".join([str(x) for x in value.tolist()]) + " "
        self.on_finalized_text(printable_text)


def find_audio_segments_regex(text):
    """
    Find all substrings between <|begin_of_audio|> and <|end_of_audio|> using regex.

    Args:
        text (str): The input string to search through

    Returns:
        list: A list of all found audio segments (substrings between the delimiters)
    """
    pattern = re.compile(r"<\|begin_of_audio\|>(.*?)<\|end_of_audio\|>", re.DOTALL)
    segments = pattern.findall(text)
    return [segment.strip() for segment in segments]
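
# Example (illustrative):
#     find_audio_segments_regex("a<|begin_of_audio|><|audio_1|><|audio_2|><|end_of_audio|>b")
# returns ["<|audio_1|><|audio_2|>"].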


def extract_token_ids_as_int(text):
    pattern = re.compile(r"<\|audio_(\d+)\|>")
    token_ids = pattern.findall(text)
    return [int(id) for id in token_ids]
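
# Example (illustrative): extract_token_ids_as_int("<|audio_12|><|audio_7|>") returns [12, 7].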


def custom_init_weights(module):
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            torch.nn.init.constant_(module.bias, 0)
    elif isinstance(module, torch.nn.BatchNorm2d) or isinstance(module, torch.nn.BatchNorm1d):
        torch.nn.init.constant_(module.weight, 1)
        torch.nn.init.constant_(module.bias, 0)
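
# Note: custom_init_weights is only applied in the generate benchmarks below
# (self.model.apply(custom_init_weights)); it re-initializes the weights, so the benchmarked
# model no longer produces meaningful text, presumably so generation does not terminate early
# at an EOS token and the full max_new_tokens budget gets timed.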


class S2SInference:
    def __init__(
        self, model_name_or_path, audio_tokenizer_path, audio_tokenizer_type, flow_path=None
    ):
        config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        )

        if "qwen2" in config.model_type.lower():
            from evaluation.get_chat_template import qwen2_chat_template as chat_template

            add_generation_prompt = True
            default_system_message = []

        if "hunyuan" in config.model_type.lower():
            from evaluation.get_chat_template import hunyuan_chat_template as chat_template

            add_generation_prompt = False
            default_system_message = [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant.",
                }
            ]

        luke_system_message = [
            {
                "role": "system",
                "content": "Your Name: Luke\nYour Gender: male\n\nRespond in a text-audio interleaved manner.",
            },
        ]

        tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            chat_template=chat_template,
        )
        # print(f"{tokenizer=}")
        print(f"{tokenizer.get_chat_template()=}")

        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            device_map=device_map,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2",
        ).eval()
        # print("model", model)
        print(f"{model.config.model_type=}")
        print(f"{model.hf_device_map=}")

        model.generation_config = GenerationConfig.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )

        model.generation_config.max_new_tokens = 8192
        model.generation_config.chat_format = "chatml"
        model.generation_config.max_window_size = 8192
        model.generation_config.use_cache = True
        # model.generation_config.use_cache = False
        model.generation_config.do_sample = False
        model.generation_config.temperature = 1.0
        model.generation_config.top_k = 50
        model.generation_config.top_p = 1.0
        model.generation_config.num_beams = 1
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        if model.config.model_type == "hunyuan":
            model.generation_config.eos_token_id = tokenizer.eos_id
        print(f"{model.generation_config=}")

        audio_tokenizer = get_audio_tokenizer(
            audio_tokenizer_path,
            audio_tokenizer_type,
            flow_path=flow_path,
            rank=audio_tokenizer_rank,
        )

        self.model = model
        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        self.add_generation_prompt = add_generation_prompt
        self.default_system_message = default_system_message
        self.luke_system_message = luke_system_message

        audio_0_id = tokenizer("<|audio_0|>").input_ids[0]
        print(f"{audio_0_id=}")

    def benchmark_forward(self, mtp_inference_mode):
        print("-" * 100)
        print("benchmark_forward...")
        print(f"{mtp_inference_mode=}")

        total_time = 0
        past_key_values = None
        use_cache = True

        self.model.input_ids = None
        self.model.inputs_embeds = None
        self.model.hidden_states = [None] * (self.model.config.num_nextn_predict_layers + 1)
        self.model.position_ids = None
        self.model.attention_mask = None
        self.model.mtp_idx = -1
        self.model.num_prefill_tokens = -1

        model_max_length = 1024

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model._prepare_mtp_for_generation(mtp_inference_mode, model_max_length)
        else:
            self.model._prepare_mtp_for_generation(
                self.model.generation_config.mtp_inference_mode, model_max_length
            )

        for i in tqdm.tqdm(range(1, model_max_length + 1)):
            if use_cache:
                input_ids = torch.tensor([i - 1], dtype=torch.long).unsqueeze(0).to("cuda")
                position_ids = torch.tensor([i - 1], dtype=torch.long).unsqueeze(0).to("cuda")
            else:
                input_ids = torch.arange(i, dtype=torch.long).unsqueeze(0).to("cuda")
                position_ids = torch.arange(i, dtype=torch.long).unsqueeze(0).to("cuda")
            attention_mask = torch.tensor([1] * i, dtype=torch.float).unsqueeze(0).to("cuda")

            torch.cuda.synchronize()
            start = time.time()

            output = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                num_logits_to_keep=1,
            )

            torch.cuda.synchronize()
            end = time.time()
            total_time += end - start
            # print(f"{i=} {total_time=}")

            past_key_values = output.past_key_values

        print()
        print(f"{total_time=}")
        print(f"second/token {total_time/model_max_length=}")
        print(f"token/second {model_max_length/total_time=}")

        if mtp_inference_mode is not None:
            self.model.mtp_inference_mode = ori_mtp_inference_mode

    def benchmark_generate(self, mtp_inference_mode):
        self.model.apply(custom_init_weights)

        print("-" * 100)
        print("benchmark_generate...")
        print(f"{mtp_inference_mode=}")

        total_time = 0

        self.model.generation_config.use_cache = True
        self.model.generation_config.max_new_tokens = 8192

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        input_ids = torch.tensor([0], dtype=torch.long).unsqueeze(0).to("cuda")

        torch.cuda.synchronize()
        start = time.time()

        output = self.model.generate(
            input_ids,
        )
        # print(f"{output.size()=}")

        torch.cuda.synchronize()
        end = time.time()
        total_time += end - start

        print()
        print(f"{total_time=}")
        print(f"second/token {total_time/output.size(1)=}")
        print(f"token/second {output.size(1)/total_time=}")

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode

    def benchmark_generate_stream(self, mtp_inference_mode):
        print("-" * 100)
        print("benchmark_generate_stream...")
        print(f"{mtp_inference_mode=}")

        self.model.apply(custom_init_weights)

        total_time = 0

        self.model.generation_config.use_cache = True
        # model_max_length = 8192
        model_max_length = 4096
        # model_max_length = 2048
        # model_max_length = 1024
        num_prefill_tokens = 32
        self.model.generation_config.max_new_tokens = model_max_length
        self.model.generation_config.do_sample = False

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        input_ids = torch.tensor([0] * num_prefill_tokens, dtype=torch.long).unsqueeze(0).to("cuda")

        streamer = BenchmarkIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(input_ids=input_ids, streamer=streamer)
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)

        token_decode_time = []
        torch.cuda.synchronize()
        start = time.time()

        thread.start()
        generated_text = ""
        for new_text in tqdm.tqdm(streamer, total=model_max_length):
            generated_text += new_text
            end = time.time()
            token_decode_time.append(end - start)
            yield new_text
        # print(f"{len(generated_text)}")

        torch.cuda.synchronize()
        end = time.time()
        total_time += end - start

        print()
        print(f"{token_decode_time[-1]=}")
        print(f"{streamer.num_decode_tokens=}")
        print(f"second/token {token_decode_time[-1]/streamer.num_decode_tokens=}")
        print(f"token/second {streamer.num_decode_tokens/token_decode_time[-1]=}")

        # if mtp_inference_mode is None:
        #     mtp_inference_mode = []
        # with open(f'token_decode_time_{str(mtp_inference_mode)}.json', 'w') as f:
        #     json.dump(token_decode_time, f)

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode

    def run_infer(
        self,
        audio_path=None,
        prompt_audio_path=None,
        stream_stride=4,
        max_returned_tokens=4096,
        sample_rate=16000,
        request_id="",
        audio_feats=None,
        message="",
        use_past=False,
        mode="luke",
        do_sample=False,
        mtp_inference_mode=None,
    ):
        AUD_TAG_TOKEN = "<|audio|>"
        AUD_CONTEXT_TOKEN = "<|context_of_audio|>"
        AUD_START_TOKEN = "<|begin_of_audio|>"
        AUD_END_TOKEN = "<|end_of_audio|>"

        if prompt_audio_path is not None:
            system_message = [
                {
                    "role": "system",
                    "content": "Your Voice: <|audio|>\n",
                },
            ]
        elif mode == "luke":
            system_message = self.luke_system_message
        else:
            system_message = self.default_system_message

        if prompt_audio_path is not None and self.audio_tokenizer.apply_to_role(
            "user", is_discrete=True
        ):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            system_message[-1]["content"] = system_message[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        if audio_path is not None:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message + "\n<|audio|>",
                },
            ]
        else:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message,
                },
            ]

        if audio_path is not None and self.audio_tokenizer.apply_to_role(
            "user", is_discrete=True
        ):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            messages[-1]["content"] = messages[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
        )

        if (audio_path is not None or prompt_audio_path is not None) and self.audio_tokenizer.apply_to_role(
            "user", is_contiguous=True
        ):
            # contiguous codec
            audio_paths = []
            if audio_path is not None:
                audio_paths.append(audio_path)
            if prompt_audio_path is not None:
                audio_paths.append(prompt_audio_path)
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, audio_paths, self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
        print("input", self.tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)

        self.model.generation_config.do_sample = do_sample

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        outputs = self.model.generate(
            input_ids,
            audios=audios,
            audio_indices=audio_indices,
        )

        output = self.tokenizer.decode(outputs[0], skip_special_tokens=False)
        print(f"{output=}", flush=True)

        audio_offset = self.tokenizer.convert_tokens_to_ids("<|audio_0|>")

        audio_tokens = []
        for token_id in outputs[0]:
            if token_id >= audio_offset:
                audio_tokens.append(token_id - audio_offset)

        if len(audio_tokens) > 0:
            tts_speech = self.audio_tokenizer.decode(
                audio_tokens, source_speech_16k=prompt_audio_path
            )
        else:
            tts_speech = None

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode

        return output, tts_speech

    def run_infer_stream(
        self,
        audio_path=None,
        prompt_audio_path=None,
        stream_stride=4,
        max_returned_tokens=4096,
        sample_rate=16000,
        request_id="",
        audio_feats=None,
        message="",
        use_past=False,
        mode="luke",
        do_sample=False,
        mtp_inference_mode=None,
    ):
        if prompt_audio_path is not None:
            system_message = [
                {
                    "role": "system",
                    "content": "Your Voice: <|audio|>\n",
                },
            ]
        elif mode == "luke":
            system_message = self.luke_system_message
        else:
            system_message = self.default_system_message

        if prompt_audio_path is not None and self.audio_tokenizer.apply_to_role(
            "user", is_discrete=True
        ):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(prompt_audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            system_message[-1]["content"] = system_message[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        if audio_path is not None:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message + "\n<|audio|>",
                },
            ]
        else:
            messages = system_message + [
                {
                    "role": "user",
                    "content": message,
                },
            ]

        if audio_path is not None and self.audio_tokenizer.apply_to_role(
            "user", is_discrete=True
        ):
            # discrete codec
            audio_tokens = self.audio_tokenizer.encode(audio_path)
            audio_tokens = "".join(f"<|audio_{i}|>" for i in audio_tokens)
            messages[-1]["content"] = messages[-1]["content"].replace(
                "<|audio|>", f"<|begin_of_audio|>{audio_tokens}<|end_of_audio|>"
            )

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=self.add_generation_prompt,
        )

        if (audio_path is not None or prompt_audio_path is not None) and self.audio_tokenizer.apply_to_role(
            "user", is_contiguous=True
        ):
            # contiguous codec
            audio_paths = []
            if audio_path is not None:
                audio_paths.append(audio_path)
            if prompt_audio_path is not None:
                audio_paths.append(prompt_audio_path)
            input_ids, audios, audio_indices = add_audio_input_contiguous(
                input_ids, audio_paths, self.tokenizer, self.audio_tokenizer
            )
        else:
            audios = None
            audio_indices = None

        input_ids = torch.tensor([input_ids], dtype=torch.long).to("cuda")
        print("input", self.tokenizer.decode(input_ids[0], skip_special_tokens=False), flush=True)

        self.model.generation_config.do_sample = do_sample

        if mtp_inference_mode is not None:
            ori_mtp_inference_mode = self.model.generation_config.mtp_inference_mode
            self.model.generation_config.mtp_inference_mode = mtp_inference_mode

        streamer = TextAudioIteratorStreamer(self.tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            audios=audios,
            audio_indices=audio_indices,
            streamer=streamer,
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # generated_text = ""
        for new_text in streamer:
            # generated_text += new_text
            yield new_text
        # torch.cuda.synchronize()

        if mtp_inference_mode is not None:
            self.model.generation_config.mtp_inference_mode = ori_mtp_inference_mode


def benchmark_llm():
    for mtp_inference_mode, tag in zip(
        [
            [8192, 0],
            [1, 4, 3, 8, 4, 10],
            [1, 10, 4, 10],
            [1, 10],
        ],
        [
            "Vanilla",
            "Balance",
            "Boost",
            "Turbo",
        ],
    ):
        print("=" * 100)
        print("benchmark_llm")
        print(f"{tag}")

        s2s_inference.benchmark_forward(mtp_inference_mode)
        s2s_inference.benchmark_generate(mtp_inference_mode)

        generated_text = ""
        for new_text in s2s_inference.benchmark_generate_stream(
            mtp_inference_mode=mtp_inference_mode
        ):
            generated_text += new_text
            # print(new_text, end="", flush=True)


def benchmark_sts():
    audio_paths = [
        "asset/介绍一下上海.wav",
        "asset/发表一个悲伤的演讲.wav",
        "asset/发表一个振奋人心的演讲.wav",
    ]

    for _ in range(10):
        print("=" * 100)
        print("benchmark_sts")

        audio_path = random.choice(audio_paths)
        print(f"{audio_path}")

        start = time.time()
        audio_idx = 0
        generated_text = ""
        all_tts_speech = []
        past_tts_speech_len = 0
        for new_text in s2s_inference.run_infer_stream(audio_path=audio_path):
            # print(new_text, end="", flush=True)
            generated_text += new_text
            if new_text == "<|end_of_audio|>":
                audio_tokens = extract_token_ids_as_int(generated_text)
                tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens, option_steps=1)
                tts_speech = tts_speech[past_tts_speech_len:]
                past_tts_speech_len += len(tts_speech)
                all_tts_speech.append(tts_speech)
                end = time.time()
                if audio_idx == 0:
                    print(audio_tokens)
                print(f"{audio_idx} audio chunk {end - start}")

                wav_path = os.path.join(output_dir, audio_path[:-4] + f"_{audio_idx}.wav")
                os.makedirs(os.path.dirname(wav_path), exist_ok=True)
                torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
                audio_idx += 1
                start = time.time()

        wav_path = os.path.join(output_dir, audio_path[:-4] + ".wav")
        tts_speech = torch.cat(all_tts_speech, dim=0)
        os.makedirs(os.path.dirname(wav_path), exist_ok=True)
        torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")


# ==============================================================
# Text
def text_task():
    for text in [
        "How many helicopters can a human eat in one sitting?",
        "你叫什么名字?",  # "What is your name?"
        "写一首诗",  # "Write a poem"
        "介绍一下上海",  # "Tell me about Shanghai"
    ]:
        print("=" * 100)
        print("text_task")
        print(f"{text=}")

        output, _ = s2s_inference.run_infer(
            message=text,
            mode=None,
            # do_sample=True,
            mtp_inference_mode=[8192, 0],
        )
        print(f"{output=}", flush=True)


# ==============================================================
# Text stream
def text_stream_task():
    for text in [
        "你叫什么名字?",
    ]:
        print("=" * 100)
        print("text_stream_task")
        print(f"{text=}")

        generated_text = ""
        for new_text in s2s_inference.run_infer_stream(
            message=text,
            mode=None,
            # do_sample=True,
            mtp_inference_mode=[8192, 0],
        ):
            generated_text += new_text
            print(new_text, end="")
        print("")


# ==============================================================
# S2S
def sts_task():
    for audio_path in [
        "asset/介绍一下上海.wav",
        "asset/发表一个悲伤的演讲.wav",
        "asset/发表一个振奋人心的演讲.wav",
        "asset/piano.mp3",
    ]:
        print("=" * 100)
        print("sts_task")
        print(f"{audio_path=}")

        output, tts_speech = s2s_inference.run_infer(
            audio_path=audio_path,
        )

        wav_path = os.path.join(output_dir, audio_path[:-4] + ".wav")
        os.makedirs(os.path.dirname(wav_path), exist_ok=True)
        torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")


# ==============================================================
# S2S stream
def sts_stream_task():
    for audio_path in [
        "asset/介绍一下上海.wav",
    ]:
        print("=" * 100)
        print("sts_stream_task")
        print(f"{audio_path=}")

        generated_text = ""
        for new_text in s2s_inference.run_infer_stream(audio_path=audio_path):
            generated_text += new_text
            print(new_text, end="")
        print("")

        audio_decode_time = []
        audio_segments = find_audio_segments_regex(generated_text)
        for audio_idx, audio_segment in enumerate(audio_segments):
            start = time.time()
            audio_tokens = extract_token_ids_as_int(audio_segment)
            # print(audio_tokens)
            tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens)
            end = time.time()
            audio_decode_time.append(end - start)

            wav_path = os.path.join(output_dir, audio_path[:-4] + f"_{audio_idx}.wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")
        # print(f"{audio_decode_time=}")


# ==============================================================
# ASR
def asr_task():
    for audio_path in [
        "/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000020_5XD21BihDd8_S00395.wav",
        "/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00424.wav",
        "/data/data/wenet-e2e/wenetspeech/data/cuts_TEST_NET.00000000/TES/TEST_NET_Y0000000050_LOLTeK1BNMo_S00045.wav",
        "/data/data/fixie-ai/librispeech_asr/test.clean/2830-3980-0034.wav",
        "/data/data/fixie-ai/librispeech_asr/test.clean/237-134500-0040.wav",
    ]:
        print("=" * 100)
        print("asr_task")
        print(f"{audio_path=}")

        output, tts_speech = s2s_inference.run_infer(
            audio_path=audio_path,
            # message="Translate the speech to text.",
            message="Convert the speech to text.",
            mode=None,
        )
        print(f"{output=}", flush=True)


# ==============================================================
# TTS
def tts_task():
    # Chinese TTS stress-test texts: digit strings, tongue twisters, nursery rhymes, and a
    # classical-novel excerpt.
    TTS_texts = [
        "我们将为全球城市的可持续发展贡献力量。",
        "通天河 灵感大王",
        "他本是我莲花池里养大的金鱼,每日浮头听经,修成手段。那一柄九瓣铜锤,乃是一枝未开的菡萏,被他运炼成兵。不知是那一日,海潮泛涨,走到此间。我今早扶栏看花,却不见这厮出拜,掐指巡纹,算着他在此成精,害你师父,故此未及梳妆,运神功,织个竹篮儿擒他。",
        "一二三四五六七八九十",
        "One Two Tree Four Five Six Seven Eight Night Ten",
        "1 2 3 4 5 6 7 8 9 10",
        "12345678910",
        "两个黄鹂鸣翠柳,一行白鹭上青天。窗含西岭千秋雪,门泊东吴万里船。",
        "坡上立着一只鹅,坡下就是一条河。宽宽的河,肥肥的鹅,鹅要过河,河要渡鹅不知是鹅过河,还是河渡鹅?",
        "扁担长,板凳宽,扁担没有板凳宽,板凳没有扁担长。扁担绑在板凳上,板凳不让扁担绑在板凳上。",
        "化肥会挥发,黑化肥发灰,灰化肥发黑。黑化肥发灰会挥发;灰化肥挥发会发黑。黑化肥挥发发灰会花飞;灰化肥挥发发黑会飞花,黑灰化肥会挥发发灰黑讳为花飞;灰黑化肥会挥发发黑灰为讳飞花。",
        "圆桌儿、方桌儿没有腿儿,墨水瓶儿里没有水儿,花瓶里有花儿没有叶儿,练习本儿上写字儿没有准儿,甘蔗好吃净是节儿。西瓜挺大没有味儿,坛儿里的小米儿长了虫儿,鸡毛掸子成了棍儿,水缸沿儿上系围裙儿,耗子打更猫打盹儿,新买的小褂儿没钉扣儿,奶奶想说没有劲儿。",
        "起床歌:小宝宝,起得早,睁开眼,眯眯笑,咿呀呀,学说话,伸伸手,要人抱。穿衣歌小胳膊,穿袖子,穿上衣,扣扣子,小脚丫,穿裤子,穿上袜子穿鞋子。小镜子-小镜子,圆又圆,看宝宝,露笑脸。闭上眼,做个梦,变月亮,挂上天。小铃铛叮铃铃,叮铃铃,一会远,一会近。小宝宝,耳朵灵,听铃声,找到铃。学画画小宝宝,学画画,大蜡笔,手中拿,画小鸭,叫嘎嘎,画小马,骑回家。大鞋子大鞋子,像只船,爸爸穿,我也穿,一二一,向前走,走呀走,翻了船。逛公园逛公园,宝宝笑,东看看,西瞧瞧,花儿香,鸟儿叫,小草绿,小树摇。看画报小娃娃,看画报,睁大眼,仔细瞧,布娃娃,哈哈笑,伸伸手,要你抱。搭积木大积木,红黄兰,小宝宝,最爱玩,搭火车,钻山洞,盖高楼,连着天。小汽车小汽车,嘀嘀嘀,开过来,开过去,小宝宝,当司机,送妈妈,上班去。藏猫猫儿歌:躲猫猫,躲猫猫, 猫猫、猫猫在哪里?喵……猫咪在这里。",
    ]

    for text in TTS_texts:
        print("=" * 100)
        print("tts_task")
        print(f"{text=}")

        output, tts_speech = s2s_inference.run_infer(
            message="Convert the text to speech.\n" + text,
            mode=None,
            do_sample=True,
        )

        wav_path = os.path.join(output_dir, text[:16] + ".wav")
        os.makedirs(os.path.dirname(wav_path), exist_ok=True)
        torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")

    # ==============================================================
    # Clone TTS
    for text in TTS_texts:
        for prompt_audio_path in [
            "asset/2631296891109983590.wav",
            "asset/379838640-d5ff0815-74f8-4738-b0f1-477cfc8dcc2d.wav",
            "asset/4202818730519913143.wav",
        ]:
            print("=" * 100)
            print("tts_task")
            print(f"{text=} {prompt_audio_path=}")

            output, tts_speech = s2s_inference.run_infer(
                prompt_audio_path=prompt_audio_path,
                # message="Translate the text to speech.\n" + text,
                message="Convert the text to speech.\n" + text,
                mode=None,
                do_sample=True,
            )

            wav_path = os.path.join(output_dir, prompt_audio_path[:16] + "_" + text[:16] + ".wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")


# ==============================================================
# TTS stream
def tts_stream_task():
    TTS_texts = [
        "他本是我莲花池里养大的金鱼,每日浮头听经,修成手段。那一柄九瓣铜锤,乃是一枝未开的菡萏,被他运炼成兵。不知是那一日,海潮泛涨,走到此间。我今早扶栏看花,却不见这厮出拜,掐指巡纹,算着他在此成精,害你师父,故此未及梳妆,运神功,织个竹篮儿擒他。",
    ]

    for text in TTS_texts:
        print("=" * 100)
        print("tts_stream_task")
        print(f"{text=}")

        generated_text = ""
        for new_text in s2s_inference.run_infer_stream(
            message="Convert the text to speech.\n" + text,
            mode=None,
            do_sample=True,
        ):
            generated_text += new_text
            print(new_text, end="")
        print("")

        audio_segments = find_audio_segments_regex(generated_text)
        for audio_idx, audio_segment in enumerate(audio_segments):
            audio_tokens = extract_token_ids_as_int(audio_segment)
            # print(audio_tokens)
            tts_speech = s2s_inference.audio_tokenizer.decode(audio_tokens)

            wav_path = os.path.join(output_dir, text[:16] + f"_{audio_idx}.wav")
            os.makedirs(os.path.dirname(wav_path), exist_ok=True)
            torchaudio.save(wav_path, tts_speech.unsqueeze(0), 22050, format="wav")


s2s_inference = S2SInference(
    model_name_or_path, audio_tokenizer_path, audio_tokenizer_type, flow_path=flow_path
)

text_task()
text_stream_task()
sts_task()
sts_stream_task()
asr_task()
tts_task()
tts_stream_task()
benchmark_sts()
benchmark_llm()
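
# Minimal single-call sketch (illustrative, using the s2s_inference object built above; paths are
# placeholders): a one-off speech-to-speech request whose reply audio is saved at the 22050 Hz
# rate used throughout this script.
#
#     output_text, tts_speech = s2s_inference.run_infer(audio_path="asset/介绍一下上海.wav")
#     if tts_speech is not None:
#         torchaudio.save("reply.wav", tts_speech.unsqueeze(0), 22050, format="wav")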