YanshekWoo committed on
Commit
9e02690
1 Parent(s): 2d4c0df

initial app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import List, Optional
3
+ from transformers import BertTokenizer, BartForConditionalGeneration
4
+
5
# Title and description text handed to the Gradio interface for display.
title = "HIT-TMG/dialogue-bart-large-chinese"
# NOTE: the "\n" escapes are kept in addition to the physical line breaks,
# exactly as in the original text; only the spelling of "separated" is fixed.
description = """
This is a seq2seq model fine-tuned on several Chinese dialogue datasets, from bart-large-chinese. \n
See some details of model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese . \n\n
Besides starting the conversation from scratch, you can also input the whole dialogue history utterance by utterance separated by '[SEP]'. \n
(e.g. "可以认识一下吗[SEP]当然可以啦,你好。[SEP]嘿嘿你好,请问你最近在忙什么呢?[SEP]我最近养了一只狗狗,我在训练它呢。") \n
"""
12
+
13
+
14
# Single source for the checkpoint name used by both the tokenizer and the model.
_CHECKPOINT = "HIT-TMG/dialogue-bart-large-chinese-DuSinc"

# Upper bound handed to the tokenizer as ``max_length`` when encoding input.
max_length = 512

tokenizer = BertTokenizer.from_pretrained(_CHECKPOINT)
# Over-long inputs are truncated from the left-hand side of the sequence.
tokenizer.truncation_side = 'left'

model = BartForConditionalGeneration.from_pretrained(_CHECKPOINT)
19
+
20
+
21
def chat_func(input_utterance: str, history: Optional[List[str]] = None):
    """Generate the next reply and return the chat display plus the new state.

    ``input_utterance`` is split on ``tokenizer.sep_token``, so one input may
    carry several turns. Those turns are added after the earlier ``history``
    (if any), the model generates a reply from the joined turns, and the reply
    is appended as the newest turn.

    Args:
        input_utterance: Text entered by the user; may contain several turns
            joined by ``tokenizer.sep_token``.
        history: Flat list of earlier turns, or ``None`` on the first call.
            The passed list is left untouched; a new list is returned.

    Returns:
        A ``(display_utterances, history)`` tuple. ``display_utterances`` is a
        list of 2-tuples of consecutive turns, and ``history`` is the new flat
        list of turns including the generated reply.
    """
    sep_token = tokenizer.sep_token

    # Work on a copy: if tokenization or generation raises, the caller's list
    # must not be left holding turns that never received a reply.
    turns = list(history) if history is not None else []
    turns.extend(input_utterance.split(sep_token))

    history_str = "[history] " + sep_token.join(turns)

    input_ids = tokenizer(
        history_str,
        return_tensors='pt',
        truncation=True,
        max_length=max_length,
    ).input_ids

    # generate() returns a batch; only one sequence was submitted.
    output_ids = model.generate(input_ids, max_new_tokens=30)[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    turns.append(response)

    # Pair consecutive turns so that the freshly generated reply is always the
    # second element of the last pair. With an odd number of turns the first
    # one has no partner, so it is shown alone next to an empty string.
    offset = len(turns) % 2
    display_utterances = [("", turns[0])] if offset else []
    display_utterances.extend(
        (turns[i], turns[i + 1]) for i in range(offset, len(turns) - 1, 2)
    )

    return display_utterances, turns
47
+
48
+
49
# Single-line text field in which the user types the current utterance.
_utterance_box = gr.Textbox(lines=1, placeholder="Input current utterance")

# chat_func takes (utterance, state) and returns (chatbot pairs, state), so
# the interface is wired with a matching "state" slot on both sides.
demo = gr.Interface(
    fn=chat_func,
    inputs=[_utterance_box, "state"],
    outputs=["chatbot", "state"],
    title=title,
    description=description,
)


# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
58
+