zxdu20 committed
Commit d2bbc82
1 Parent(s): 2449bdc

Fix Chinese punctuation

Files changed (1)
  1. modeling_chatglm.py +18 -4
modeling_chatglm.py CHANGED
@@ -4,6 +4,7 @@ import math
 import copy
 import os
 import warnings
+import re
 
 import torch
 import torch.utils.checkpoint
@@ -1085,6 +1086,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             for layer_past in past
         )
 
+    def process_response(self, response):
+        response = response.strip()
+        response = response.replace("[[训练时间]]", "2023年")
+        punkts = [
+            [",", ","],
+            ["!", "!"],
+            [":", ":"],
+            [";", ";"],
+            ["\?", "?"],
+        ]
+        for item in punkts:
+            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
+            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
+        return response
+
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
              do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
@@ -1107,8 +1123,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         outputs = self.generate(**input_ids, **gen_kwargs)
         outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
         response = tokenizer.decode(outputs)
-        response = response.strip()
-        response = response.replace("[[训练时间]]", "2023年")
+        response = self.process_response(response)
         history = history + [(query, response)]
         return response, history
 
@@ -1134,8 +1149,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         for outputs in self.stream_generate(**input_ids, **gen_kwargs):
             outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
             response = tokenizer.decode(outputs)
-            response = response.strip()
-            response = response.replace("[[训练时间]]", "2023年")
+            response = self.process_response(response)
             new_history = history + [(query, response)]
             yield response, new_history
 
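The new process_response helper centralizes the post-generation cleanup that chat and stream_chat previously duplicated: it strips whitespace, replaces the [[训练时间]] ("training time") placeholder with 2023年 ("the year 2023"), and converts ASCII punctuation to its full-width Chinese counterpart whenever the mark is directly adjacent to a CJK character in the U+4E00–U+9FFF range. Below is a minimal standalone sketch of the same logic as a free function, runnable outside the model class; the function name and sample strings are illustrative, not part of the commit.

import re

def process_response(response: str) -> str:
    # Trim and substitute the training-time placeholder, as in the commit.
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # [ASCII regex pattern, full-width replacement] pairs; '?' is escaped
    # because the first element is interpolated into a regex pattern.
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for ascii_pat, fullwidth in punkts:
        # Widen the mark when it immediately follows a CJK character...
        response = re.sub(r"([\u4e00-\u9fff])%s" % ascii_pat, r"\1%s" % fullwidth, response)
        # ...and when it immediately precedes one.
        response = re.sub(r"%s([\u4e00-\u9fff])" % ascii_pat, r"%s\1" % fullwidth, response)
    return response

print(process_response("你好,世界! 训练于[[训练时间]]."))  # -> 你好,世界! 训练于2023年.
print(process_response("Hello, world!"))  # unchanged: no adjacent CJK character

Because both substitutions are anchored on a CJK character class, English punctuation in mixed-language responses is left alone; only marks that directly touch Chinese text are widened. Note that the ASCII full stop is not in the list, so "." is never converted.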