# hesijun - initial commit (b71ae80)
import paddle
import gradio as gr
from paddlenlp.transformers import (UnifiedTransformerLMHeadModel,
UnifiedTransformerTokenizer)
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.tts.infer import TTSExecutor
# Speech executors shared by every request: `asr` is called with an audio
# file and its result is used as text; `tts` is called with text.
asr = ASRExecutor()
tts = TTSExecutor()
# warmup ASR and TTS
# Runs one ASR -> TTS round trip on the local "zh.wav" sample at import time
# and prints whatever the TTS call returns (presumably so the first real
# request does not pay the initialization cost -- confirm).
print(tts(text=asr("zh.wav", force_yes=True)))
# Identifier of the pretrained dialogue model; the same value is used to load
# both the model weights and the matching tokenizer.
model_name_or_path = 'plato-mini'
model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)
# Put the model in evaluation mode; this script only uses it for generation.
model.eval()
def chat(audio, history):
    """Run one voice-chat turn: recorded speech in, spoken reply out.

    The recording is transcribed with ``asr``, the transcript is appended to
    the flattened conversation so far, five candidate replies are sampled
    from ``model`` and the highest-scoring one is kept, and that reply is
    synthesized with ``tts``.

    Args:
        audio: The recording to transcribe, passed straight to ``asr`` (the
            UI supplies a file path). ``None`` or an empty value means
            nothing was recorded.
        history: List of ``(message, response)`` pairs from earlier turns, or
            a falsy value on the first turn. A non-empty list is appended to
            in place.

    Returns:
        A ``(output_file, history, history)`` tuple. ``output_file`` is
        whatever ``tts`` returns for the reply (used by the UI as an audio
        file path), or ``None`` when there was no recording. ``history`` is
        returned twice because the UI wires it to both the chat display and
        the session state.
    """
    history = history or []

    # Submitting without a recording hands us no audio; skip the turn instead
    # of passing a missing file to the ASR engine.
    if not audio:
        return None, history, history

    message = asr(audio, force_yes=True)

    # Flatten the (message, response) pairs into one alternating list of
    # utterances and add the new user message at the end.
    history_input = [text for turn in history for text in turn]
    history_input.append(message)

    inputs = tokenizer.dialogue_encode(
        history_input,
        add_start_token_as_response=True,
        return_tensors=True,
        is_split_into_words=False,
    )
    # Cast the token ids to int64 before handing them to generate().
    inputs['input_ids'] = inputs['input_ids'].astype('int64')

    # Sample five candidate replies with nucleus (top-p) sampling.
    ids, scores = model.generate(
        input_ids=inputs['input_ids'],
        token_type_ids=inputs['token_type_ids'],
        position_ids=inputs['position_ids'],
        attention_mask=inputs['attention_mask'],
        decode_strategy="sampling",
        num_return_sequences=5,
        top_p=0.95,
    )

    # Keep the candidate with the highest score.
    index = paddle.argmax(scores)
    # Decoded tokens come back space-separated; the spaces are stripped
    # (presumably because the model produces Chinese text -- confirm).
    response = tokenizer.decode(
        ids[index], skip_special_tokens=True
    ).replace(" ", "")

    history.append((message, response))

    # NOTE(review): every call writes to the same "output.wav"; if several
    # sessions run at once they may overwrite each other's audio -- confirm.
    output_file = tts(text=response, output="output.wav")
    return output_file, history, history
# Gradio UI around chat(): a microphone recording plus the session state go
# in; the spoken reply, the chat transcript, and the updated session state
# come out. The two "state" entries carry `history` between calls.
demo = gr.Interface(
    chat,
    inputs=[
        # Microphone recording, delivered to chat() as a file path.
        gr.Audio(source="microphone", type="filepath"),
        "state"],
    outputs=[
        # chat() returns the synthesized reply as a file path.
        gr.Audio(type="filepath"),
        # Transcript of (message, response) pairs with a two-color scheme.
        gr.Chatbot().style(color_map=("green", "pink")),
        "state"
    ],
    # Disable Gradio's flagging feature for this demo.
    allow_flagging="never",
)
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()