xusong28
committed on
Commit
•
d0547d2
1
Parent(s):
dc5d472
update
Browse files- demo_chatbot_jddc.py +4 -1
- demo_corrector.py +4 -2
- demo_sum.py +4 -0
demo_chatbot_jddc.py
CHANGED
@@ -17,11 +17,14 @@ tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-jddc")
|
|
17 |
|
18 |
|
19 |
def predict(input, history=[]):
|
|
|
|
|
|
|
20 |
# append the new user input tokens to the chat history
|
21 |
history = history + [input] # history如果包含错误的response,可能会造成误差传递
|
22 |
|
23 |
# tokenize the new input sentence
|
24 |
-
bot_input_ids = tokenizer.encode("".join(history)[-500:], return_tensors='pt')
|
25 |
|
26 |
# bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
|
27 |
|
17 |
|
18 |
|
19 |
def predict(input, history=[]):
|
20 |
+
"""
|
21 |
+
拼接方案:直接拼接history作为输入,不区分角色。虽然简单粗糙,但是encoder-decoder架构不会混淆输入和输出(如果是gpt架构就需要区分角色了)。
|
22 |
+
"""
|
23 |
# append the new user input tokens to the chat history
|
24 |
history = history + [input] # history如果包含错误的response,可能会造成误差传递
|
25 |
|
26 |
# tokenize the new input sentence
|
27 |
+
bot_input_ids = tokenizer.encode("".join(history)[-500:], return_tensors='pt')
|
28 |
|
29 |
# bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
|
30 |
|
demo_corrector.py
CHANGED
@@ -52,7 +52,9 @@ def mock_data():
|
|
52 |
|
53 |
|
54 |
def correct(sent):
|
55 |
-
|
|
|
|
|
56 |
corrected_sent, errs = corrector.bert_correct(sent)
|
57 |
# corrected_sent, errs = mock_data()
|
58 |
print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
|
@@ -79,7 +81,7 @@ corr_iface = gr.Interface(
|
|
79 |
|
80 |
),
|
81 |
gr.JSON(
|
82 |
-
label="JSON Output"
|
83 |
)
|
84 |
],
|
85 |
examples=error_sentences,
|
52 |
|
53 |
|
54 |
def correct(sent):
|
55 |
+
"""
|
56 |
+
{"text": sent, "entities": [{}, {}] } 是 gradio 要求的格式,详见 https://www.gradio.app/docs/highlightedtext
|
57 |
+
"""
|
58 |
corrected_sent, errs = corrector.bert_correct(sent)
|
59 |
# corrected_sent, errs = mock_data()
|
60 |
print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
|
81 |
|
82 |
),
|
83 |
gr.JSON(
|
84 |
+
# label="JSON Output"
|
85 |
)
|
86 |
],
|
87 |
examples=error_sentences,
|
demo_sum.py
CHANGED
@@ -109,10 +109,14 @@ gen_mode_params = {
|
|
109 |
"num_beams": 10,
|
110 |
"do_sample": False,
|
111 |
},
|
|
|
|
|
112 |
"contrastive search": {
|
113 |
"top_k": 4,
|
114 |
"penalty_alpha": 0.2,
|
115 |
},
|
|
|
|
|
116 |
"diverse beam search": {
|
117 |
"num_beams": 5,
|
118 |
"num_beam_groups": 5,
|
109 |
"num_beams": 10,
|
110 |
"do_sample": False,
|
111 |
},
|
112 |
+
|
113 |
+
# 算法? 复杂度?
|
114 |
"contrastive search": {
|
115 |
"top_k": 4,
|
116 |
"penalty_alpha": 0.2,
|
117 |
},
|
118 |
+
|
119 |
+
# 算法? 复杂度?
|
120 |
"diverse beam search": {
|
121 |
"num_beams": 5,
|
122 |
"num_beam_groups": 5,
|