zxdu20 committed
Commit: debaf00
1 Parent(s): 3ba9437

Fix Chinese punctuation

Files changed (1):
  1. modeling_chatglm.py +18 -4
modeling_chatglm.py CHANGED
@@ -4,6 +4,7 @@ import math
 import copy
 import os
 import warnings
+import re
 
 import torch
 import torch.utils.checkpoint
@@ -1099,6 +1100,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", "，"],
+            ["!", "！"],
+            [":", "："],
+            [";", "；"],
+            ["\?", "？"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1121,8 +1137,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history
 
@@ -1148,8 +1163,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         for outputs in self.stream_generate(**input_ids, **gen_kwargs):
             outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
             response = tokenizer.decode(outputs)
-            response = response.strip()
-            response = response.replace("[[训练时间]]", "2023年")
+            response = self.process_response(response)
             new_history = history + [(query, response)]
             yield response, new_history
 
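The commit routes the decoded output of both chat() and stream_chat() through the new process_response() helper, which keeps the earlier strip() and "[[训练时间]]" replacement and additionally converts half-width ASCII punctuation to its full-width Chinese form whenever it is adjacent to a CJK character (Unicode range \u4e00-\u9fff). The standalone sketch below reproduces that normalization outside the model class for illustration only; the function name normalize_punctuation and the sample strings are assumptions for the example, not part of the commit.

import re

# Standalone sketch of the punctuation normalization added in this commit.
# Half-width punctuation that touches a CJK character is rewritten to the
# corresponding full-width character; punctuation in pure-ASCII text is left alone.
def normalize_punctuation(text: str) -> str:
    punkts = [
        [",", "，"],
        ["!", "！"],
        [":", "："],
        [";", "；"],
        ["\\?", "？"],  # "?" is a regex metacharacter, so the pattern escapes it
    ]
    for half, full in punkts:
        # CJK character followed by half-width punctuation -> full-width
        text = re.sub(r"([\u4e00-\u9fff])%s" % half, r"\1%s" % full, text)
        # half-width punctuation followed by CJK character -> full-width
        text = re.sub(r"%s([\u4e00-\u9fff])" % half, r"%s\1" % full, text)
    return text

print(normalize_punctuation("你好,世界!"))     # -> 你好，世界！
print(normalize_punctuation("Hello, world!"))  # ASCII-only text is unchanged

Inside the model itself, the same substitutions now run automatically after tokenizer.decode() in chat() and stream_chat(); the previous per-method strip() and replace() lines are folded into process_response().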