|
import json |
|
import os.path |
|
import tempfile |
|
import sys |
|
import re |
|
import uuid |
|
import requests |
|
from argparse import ArgumentParser |
|
|
|
import torchaudio |
|
from transformers import WhisperFeatureExtractor, AutoTokenizer, AutoModel |
|
from speech_tokenizer.modeling_whisper import WhisperVQEncoder |
|
|
|
|
|
|
|
|
|
sys.path.insert(0, "./cosyvoice") |
|
sys.path.insert(0, "./third_party/Matcha-TTS") |
|
|
|
from speech_tokenizer.utils import extract_speech_token |
|
|
|
import gradio as gr |
|
import torch |
|
|
|
audio_token_pattern = re.compile(r"<\|audio_(\d+)\|>") |
|
|
|
from flow_inference import AudioDecoder |
|
|
|
if __name__ == "__main__":
    # Command-line options: where to serve the demo and where the model assets live.
    parser = ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    # Use a real int default rather than relying on argparse re-parsing a string
    # default through ``type``.
    parser.add_argument("--port", type=int, default=8888)
    parser.add_argument("--flow-path", type=str, default="./glm-4-voice-decoder")
    parser.add_argument("--model-path", type=str, default="./glm-4-voice-9b-int8")
    parser.add_argument("--tokenizer-path", type=str, default="./glm-4-voice-tokenizer")
    args = parser.parse_args()

    # Files expected inside the --flow-path directory.
    flow_config = os.path.join(args.flow_path, "config.yaml")
    flow_checkpoint = os.path.join(args.flow_path, 'flow.pt')
    hift_checkpoint = os.path.join(args.flow_path, 'hift.pt')

    # Module-level state; the None placeholders are filled in by initialize_fn().
    glm_tokenizer = None
    device = "cuda"
    audio_decoder: AudioDecoder | None = None
    whisper_model, feature_extractor = None, None
|
|
|
|
|
def initialize_fn():
    """Populate the module-level tokenizer, audio decoder and Whisper components.

    The function is a no-op when ``audio_decoder`` has already been set, so it can
    be called more than once. Nothing is returned; results are stored in globals.
    """
    global audio_decoder, feature_extractor, whisper_model, glm_model, glm_tokenizer

    already_loaded = audio_decoder is not None
    if already_loaded:
        return

    # Text tokenizer, loaded from --model-path with remote code enabled.
    glm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # Token-to-waveform decoder built from the config and the two checkpoints.
    audio_decoder = AudioDecoder(
        config_path=flow_config,
        flow_ckpt_path=flow_checkpoint,
        hift_ckpt_path=hift_checkpoint,
        device=device,
    )

    # Speech tokenizer: encoder in eval mode on the target device, plus its
    # feature extractor, both loaded from --tokenizer-path.
    encoder = WhisperVQEncoder.from_pretrained(args.tokenizer_path)
    whisper_model = encoder.eval().to(device)
    feature_extractor = WhisperFeatureExtractor.from_pretrained(args.tokenizer_path)
|
|
|
|
|
def clear_fn():
    """Return blank values for the seven components that a reset clears.

    The tuple holds, in order: two fresh empty lists, three empty strings and
    two ``None`` values.
    """
    empty_chat, empty_history = [], []
    blank_texts = ('',) * 3
    no_audio = (None, None)
    return (empty_chat, empty_history, *blank_texts, *no_audio)
|
|
|
|
|
def inference_fn(
        temperature: float,
        top_p: float,
        max_new_token: int,
        input_mode,
        audio_path: str | None,
        input_text: str | None,
        history: list[dict],
        previous_input_tokens: str,
        previous_completion_tokens: str,
):
    """Run one chat turn against the generation server and stream back audio.

    This is a generator. Each yielded 6-tuple is
    ``(history, prompt, completion_text, detailed_error, streaming_audio, complete_audio)``:
    intermediate yields carry one synthesized audio chunk in ``streaming_audio``;
    the final yield carries the decoded completion text and, when any audio was
    produced, the concatenated waveform in ``complete_audio``.

    Args:
        temperature: Sampling temperature forwarded to the generation server.
        top_p: Nucleus-sampling value forwarded to the generation server.
        max_new_token: Forwarded to the server as ``max_new_tokens``.
        input_mode: ``"audio"`` selects ``audio_path``; any other value selects
            ``input_text``.
        audio_path: Path of the user's recording (audio mode).
        input_text: The user's message (text mode).
        history: Chat messages; mutated in place with the user and assistant turns.
        previous_input_tokens: Prompt text of the previous turn.
        previous_completion_tokens: Completion text of the previous turn.

    Raises:
        gr.Error: If the input required by ``input_mode`` is missing, or no audio
            tokens could be extracted from the recording.
        requests.HTTPError: If the generation server answers with an error status.
    """
    sample_rate = 22050
    generate_url = "http://localhost:10000/generate_stream"

    if input_mode == "audio":
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if audio_path is None:
            raise gr.Error("No input audio provided")
        history.append({"role": "user", "content": {"path": audio_path}})
        audio_tokens = extract_speech_token(
            whisper_model, feature_extractor, [audio_path]
        )[0]
        if len(audio_tokens) == 0:
            raise gr.Error("No audio tokens extracted")
        audio_tokens = "".join([f"<|audio_{x}|>" for x in audio_tokens])
        audio_tokens = "<|begin_of_audio|>" + audio_tokens + "<|end_of_audio|>"
        user_input = audio_tokens
        system_prompt = "User will provide you with a speech instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens. "
    else:
        if input_text is None:
            raise gr.Error("No input text provided")
        history.append({"role": "user", "content": input_text})
        user_input = input_text
        system_prompt = "User will provide you with a text instruction. Do it step by step. First, think about the instruction and respond in a interleaved manner, with 13 text token followed by 26 audio tokens."

    # Continue the previous conversation; the system prompt is added only once.
    inputs = previous_input_tokens + previous_completion_tokens
    inputs = inputs.strip()
    if "<|system|>" not in inputs:
        inputs += f"<|system|>\n{system_prompt}"
    inputs += f"<|user|>\n{user_input}<|assistant|>streaming_transcription\n"

    request_body = json.dumps({
        "prompt": inputs,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_token,
    })

    text_tokens, audio_tokens = [], []
    # Token ids at or above ``audio_offset`` are treated as audio tokens.
    audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')
    end_token_id = glm_tokenizer.convert_tokens_to_ids('<|user|>')
    complete_tokens = []
    # Empty prompt tensors for the first decoded block.
    prompt_speech_feat = torch.zeros(1, 0, 80).to(device)
    flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int64).to(device)
    this_uuid = str(uuid.uuid4())
    tts_speechs = []
    tts_mels = []
    prev_mel = None
    is_finalize = False
    block_size = 10  # first block is flushed early; later blocks use 20 tokens

    # The response is used as a context manager so the streamed connection is
    # released even if the consumer stops iterating this generator early.
    with torch.no_grad(), requests.post(generate_url, data=request_body, stream=True) as response:
        response.raise_for_status()
        for chunk in response.iter_lines():
            if not chunk:
                # iter_lines() can emit empty keep-alive lines; they carry no token.
                continue
            token_id = json.loads(chunk)["token_id"]
            if token_id == end_token_id:
                is_finalize = True

            if len(audio_tokens) >= block_size or (is_finalize and audio_tokens):
                block_size = 20
                tts_token = torch.tensor(audio_tokens, device=device).unsqueeze(0)

                if prev_mel is not None:
                    # Condition on every mel produced so far in this turn.
                    prompt_speech_feat = torch.cat(tts_mels, dim=-1).transpose(1, 2)

                tts_speech, tts_mel = audio_decoder.token2wav(
                    tts_token,
                    uuid=this_uuid,
                    prompt_token=flow_prompt_speech_token.to(device),
                    prompt_feat=prompt_speech_feat.to(device),
                    finalize=is_finalize,
                )
                prev_mel = tts_mel

                tts_speechs.append(tts_speech.squeeze())
                tts_mels.append(tts_mel)
                yield history, inputs, '', '', (sample_rate, tts_speech.squeeze().cpu().numpy()), None
                flow_prompt_speech_token = torch.cat((flow_prompt_speech_token, tts_token), dim=-1)
                audio_tokens = []

            if is_finalize:
                # Nothing after the end token is recorded, so stop reading.
                break

            complete_tokens.append(token_id)
            if token_id >= audio_offset:
                audio_tokens.append(token_id - audio_offset)
            else:
                text_tokens.append(token_id)
        # NOTE(review): if the stream ends without the end token, audio tokens still
        # buffered here (fewer than block_size) are not synthesized -- confirm intended.

    complete_text = glm_tokenizer.decode(complete_tokens, spaces_between_special_tokens=False)

    complete_audio = None
    if tts_speechs:
        # Only build and save the full waveform when at least one chunk was decoded;
        # torch.cat raises on an empty list.
        full_speech = torch.cat(tts_speechs, dim=-1).cpu()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            torchaudio.save(f, full_speech.unsqueeze(0), sample_rate, format="wav")
        history.append({"role": "assistant", "content": {"path": f.name, "type": "audio/wav"}})
        complete_audio = (sample_rate, full_speech.numpy())

    # NOTE(review): ``ignore_special_tokens`` is kept as originally written; the
    # documented transformers keyword is ``skip_special_tokens`` -- confirm intent.
    history.append({"role": "assistant", "content": glm_tokenizer.decode(text_tokens, ignore_special_tokens=False)})
    yield history, inputs, complete_text, '', None, complete_audio
|
|
|
|
|
def update_input_interface(input_mode):
    """Build visibility updates for the two input widgets.

    Returns a two-element list of ``gr.update`` objects: the first is visible and
    the second hidden when ``input_mode`` is ``"audio"``; the reverse otherwise.
    """
    show_audio = input_mode == "audio"
    return [gr.update(visible=show_audio), gr.update(visible=not show_audio)]
|
|
|
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): the Row/Column nesting below is reconstructed from statement
    # order; confirm it matches the intended layout.
    with gr.Blocks(title="GLM-4-Voice Demo", fill_height=True) as demo:
        # Generation parameters passed to inference_fn.
        with gr.Row():
            temperature = gr.Number(label="Temperature", value=0.2)
            top_p = gr.Number(label="Top p", value=0.8)
            max_new_token = gr.Number(label="Max new tokens", value=2000)

        chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, type="messages", scale=1)

        with gr.Row():
            with gr.Column():
                input_mode = gr.Radio(["audio", "text"], label="Input Mode", value="audio")
                audio = gr.Audio(label="Input audio", type='filepath', show_download_button=True, visible=True)
                text_input = gr.Textbox(label="Input text", placeholder="Enter your text here...", lines=2, visible=False)

            with gr.Column():
                submit_btn = gr.Button("Submit")
                reset_btn = gr.Button("Clear")
                output_audio = gr.Audio(label="Play", streaming=True, autoplay=True, show_download_button=False)
                complete_audio = gr.Audio(label="Last Output Audio (If Any)", show_download_button=True)

        gr.Markdown("""## Debug Info""")
        with gr.Row():
            input_tokens = gr.Textbox(label="Input Tokens", interactive=False)
            completion_tokens = gr.Textbox(label="Completion Tokens", interactive=False)

        detailed_error = gr.Textbox(label="Detailed Error", interactive=False)

        history_state = gr.State([])

        # Components receiving each tuple yielded by inference_fn, in order.
        inference_inputs = [
            temperature, top_p, max_new_token,
            input_mode, audio, text_input,
            history_state, input_tokens, completion_tokens,
        ]
        inference_outputs = [
            history_state, input_tokens, completion_tokens,
            detailed_error, output_audio, complete_audio,
        ]
        # Components reset by clear_fn, in the order of its return tuple.
        reset_outputs = [
            chatbot, history_state, input_tokens, completion_tokens,
            detailed_error, output_audio, complete_audio,
        ]

        respond = submit_btn.click(inference_fn, inputs=inference_inputs, outputs=inference_outputs)
        # Once inference finishes, copy the history state into the chatbot.
        respond.then(lambda s: s, [history_state], chatbot)

        reset_btn.click(clear_fn, outputs=reset_outputs)
        # Switching input mode clears everything, then swaps the visible input widget.
        mode_changed = input_mode.input(clear_fn, outputs=reset_outputs)
        mode_changed.then(update_input_interface, inputs=[input_mode], outputs=[audio, text_input])

    # Load the models before the server starts accepting requests.
    initialize_fn()

    demo.launch(server_port=args.port, server_name=args.host)
|
|