hesijun commited on
Commit
b71ae80
·
1 Parent(s): e1d6e16

initial commit

Browse files
Files changed (3) hide show
  1. app.py +55 -0
  2. requirements.txt +3 -0
  3. zh.wav +0 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import paddle
2
+
3
+ import gradio as gr
4
+ from paddlenlp.transformers import (UnifiedTransformerLMHeadModel,
5
+ UnifiedTransformerTokenizer)
6
+ from paddlespeech.cli.asr.infer import ASRExecutor
7
+ from paddlespeech.cli.tts.infer import TTSExecutor
8
+
9
+ asr = ASRExecutor()
10
+ tts = TTSExecutor()
11
+ # warmup ASR and TTS
12
+ print(tts(text=asr("zh.wav", force_yes=True)))
13
+ model_name_or_path = 'plato-mini'
14
+ model = UnifiedTransformerLMHeadModel.from_pretrained(model_name_or_path)
15
+ tokenizer = UnifiedTransformerTokenizer.from_pretrained(model_name_or_path)
16
+ model.eval()
17
+
18
+ def chat(audio, history):
19
+ message = asr(audio, force_yes=True)
20
+ history = history or []
21
+ history_input = [text for round in history for text in round]
22
+ history_input.append(message)
23
+ inputs = tokenizer.dialogue_encode(history_input,
24
+ add_start_token_as_response=True,
25
+ return_tensors=True,
26
+ is_split_into_words=False)
27
+ inputs['input_ids'] = inputs['input_ids'].astype('int64')
28
+ ids, scores = model.generate(
29
+ input_ids=inputs['input_ids'],
30
+ token_type_ids=inputs['token_type_ids'],
31
+ position_ids=inputs['position_ids'],
32
+ attention_mask=inputs['attention_mask'],
33
+ decode_strategy="sampling",
34
+ num_return_sequences=5,
35
+ top_p=0.95)
36
+ index = paddle.argmax(scores)
37
+ response = tokenizer.decode(ids[index], skip_special_tokens=True).replace(" ", "")
38
+ history.append((message, response))
39
+ output_file = tts(text=response, output="output.wav")
40
+ return output_file, history, history
41
+
42
+ demo = gr.Interface(
43
+ chat,
44
+ inputs=[
45
+ gr.Audio(source="microphone", type="filepath"),
46
+ "state"],
47
+ outputs=[
48
+ gr.Audio(type="filepath"),
49
+ gr.Chatbot().style(color_map=("green", "pink")),
50
+ "state"
51
+ ],
52
+ allow_flagging="never",
53
+ )
54
+ if __name__ == "__main__":
55
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ paddlepaddle
2
+ paddlenlp
3
+ paddlespeech
zh.wav ADDED
Binary file (160 kB). View file