xusong28 committed on
Commit
d0547d2
1 Parent(s): dc5d472
Files changed (3)
  1. demo_chatbot_jddc.py +4 -1
  2. demo_corrector.py +4 -2
  3. demo_sum.py +4 -0
demo_chatbot_jddc.py CHANGED
@@ -17,11 +17,14 @@ tokenizer = BertTokenizer.from_pretrained("eson/kplug-base-jddc")
17
 
18
 
19
  def predict(input, history=[]):
 
 
 
20
  # append the new user input tokens to the chat history
21
  history = history + [input] # history如果包含错误的response,可能会造成误差传递
22
 
23
  # tokenize the new input sentence
24
- bot_input_ids = tokenizer.encode("".join(history)[-500:], return_tensors='pt') # 直接用空格拼接的,太粗糙,且不区分角色。
25
 
26
  # bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
27
 
17
 
18
 
19
  def predict(input, history=[]):
20
+ """
21
+ 拼接方案:直接拼接history作为输入,不区分角色。虽然简单粗糙,但是encoder-decoder架构不会混淆输入和输出(如果是gpt架构就需要区分角色了)。
22
+ """
23
  # append the new user input tokens to the chat history
24
  history = history + [input] # history如果包含错误的response,可能会造成误差传递
25
 
26
  # tokenize the new input sentence
27
+ bot_input_ids = tokenizer.encode("".join(history)[-500:], return_tensors='pt')
28
 
29
  # bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
30
 
demo_corrector.py CHANGED
@@ -52,7 +52,9 @@ def mock_data():
52
 
53
 
54
  def correct(sent):
55
-
 
 
56
  corrected_sent, errs = corrector.bert_correct(sent)
57
  # corrected_sent, errs = mock_data()
58
  print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
@@ -79,7 +81,7 @@ corr_iface = gr.Interface(
79
 
80
  ),
81
  gr.JSON(
82
- label="JSON Output"
83
  )
84
  ],
85
  examples=error_sentences,
52
 
53
 
54
  def correct(sent):
55
+ """
56
+ {"text": sent, "entities": [{}, {}] } 是 gradio 要求的格式,详见 https://www.gradio.app/docs/highlightedtext
57
+ """
58
  corrected_sent, errs = corrector.bert_correct(sent)
59
  # corrected_sent, errs = mock_data()
60
  print("original sentence:{} => {}, err:{}".format(sent, corrected_sent, errs))
81
 
82
  ),
83
  gr.JSON(
84
+ # label="JSON Output"
85
  )
86
  ],
87
  examples=error_sentences,
demo_sum.py CHANGED
@@ -109,10 +109,14 @@ gen_mode_params = {
109
  "num_beams": 10,
110
  "do_sample": False,
111
  },
 
 
112
  "contrastive search": {
113
  "top_k": 4,
114
  "penalty_alpha": 0.2,
115
  },
 
 
116
  "diverse beam search": {
117
  "num_beams": 5,
118
  "num_beam_groups": 5,
109
  "num_beams": 10,
110
  "do_sample": False,
111
  },
112
+
113
+ # 算法? 复杂度?
114
  "contrastive search": {
115
  "top_k": 4,
116
  "penalty_alpha": 0.2,
117
  },
118
+
119
+ # 算法? 复杂度?
120
  "diverse beam search": {
121
  "num_beams": 5,
122
  "num_beam_groups": 5,