File size: 2,001 Bytes
b71ae80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import paddle

import gradio as gr
from paddlenlp.transformers import (UnifiedTransformerLMHeadModel,
                                    UnifiedTransformerTokenizer)
from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.tts.infer import TTSExecutor

# Speech front-end: a single recogniser and a single synthesiser, created once
# at import time and reused by every call to ``chat``.
asr = ASRExecutor()
tts = TTSExecutor()

# Warm-up: transcribe the bundled sample "zh.wav", synthesise the transcript
# and print what the TTS call returns, so both pipelines are initialised
# before the UI starts.
# NOTE(review): force_yes=True presumably suppresses PaddleSpeech's
# interactive confirmation prompt (e.g. before resampling) — confirm.
_warmup_text = asr("zh.wav", force_yes=True)
print(tts(text=_warmup_text))

# Dialogue model and its tokenizer, loaded from the same pretrained
# checkpoint name (presumably fetched/cached by PaddleNLP).
model_name_or_path = 'plato-mini'
model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)
# The model is only used for generation, so switch it to evaluation mode.
model.eval()

def chat(audio, history):
    """Answer one spoken turn of the conversation with a spoken reply.

    The recording is transcribed with ``asr``. The transcript is appended to
    the flattened dialogue history and passed to the PLATO-mini ``model``.
    The best-scoring sampled reply is decoded and synthesised with ``tts``.

    Args:
        audio: Path to the recorded audio, as delivered by the
            ``gr.Audio(type="filepath")`` input, or ``None`` when the user
            submits without recording anything.
        history: List of ``(user_message, bot_response)`` pairs kept in the
            Gradio session state, or ``None``/empty on the first turn.

    Returns:
        Tuple ``(output_file, history, history)``:
        - ``output_file`` is whatever ``tts`` returns for the synthesised
          reply, or ``None`` when there was no recording.
        - ``history`` is the updated conversation. It is returned twice: once
          for the Chatbot display and once for the session state.
    """
    # Work on a copy so the state object handed in by Gradio is not mutated
    # in place; the updated list is returned explicitly below.
    history = list(history or [])
    if audio is None:
        # Nothing was recorded: leave the conversation unchanged instead of
        # handing ``None`` to the ASR executor.
        return None, history, history

    message = asr(audio, force_yes=True)

    # The model consumes the dialogue as one flat list of utterances:
    # [user_1, bot_1, user_2, bot_2, ..., current user message].
    history_input = [text for turn in history for text in turn]
    history_input.append(message)
    inputs = tokenizer.dialogue_encode(
        history_input,
        add_start_token_as_response=True,
        return_tensors=True,
        is_split_into_words=False,
    )
    # NOTE(review): presumably generate() expects int64 token ids — confirm.
    inputs['input_ids'] = inputs['input_ids'].astype('int64')

    # Sample 5 candidate replies with nucleus (top-p) sampling.
    ids, scores = model.generate(
        input_ids=inputs['input_ids'],
        token_type_ids=inputs['token_type_ids'],
        position_ids=inputs['position_ids'],
        attention_mask=inputs['attention_mask'],
        decode_strategy="sampling",
        num_return_sequences=5,
        top_p=0.95,
    )
    # Keep the highest-scoring candidate.
    # NOTE(review): assumes one score per returned sequence, so the argmax
    # over the flattened scores is a row index into ``ids`` — confirm.
    index = paddle.argmax(scores)
    # Spaces are stripped from the decoded text; presumably the tokenizer
    # joins (Chinese) tokens with spaces — revisit for other languages.
    response = tokenizer.decode(
        ids[index], skip_special_tokens=True).replace(" ", "")

    history.append((message, response))
    # NOTE(review): every call writes the same "output.wav", so concurrent
    # sessions could overwrite each other's reply file.
    output_file = tts(text=response, output="output.wav")
    return output_file, history, history

# Voice UI around ``chat``. Inputs are the microphone recording (delivered
# as a file path) and the per-session state. Outputs are the spoken reply,
# the chat transcript and the updated state.
_demo_inputs = [
    gr.Audio(source="microphone", type="filepath"),
    "state",
]
_demo_outputs = [
    gr.Audio(type="filepath"),
    gr.Chatbot().style(color_map=("green", "pink")),
    "state",
]
demo = gr.Interface(
    chat,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    allow_flagging="never",  # no flag button in the UI
)

# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()