zxdu20 committed on
Commit
4b7ffbf
1 Parent(s): 373fd6b

No padding for chat function

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +2 -2
modeling_chatglm.py CHANGED
@@ -1243,7 +1243,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1243
  for i, (old_query, response) in enumerate(history):
1244
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1245
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1246
- inputs = tokenizer([prompt], return_tensors="pt", padding=True)
1247
  inputs = inputs.to(self.device)
1248
  outputs = self.generate(**inputs, **gen_kwargs)
1249
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
@@ -1269,7 +1269,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1269
  for i, (old_query, response) in enumerate(history):
1270
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1271
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1272
- inputs = tokenizer([prompt], return_tensors="pt", padding=True)
1273
  inputs = inputs.to(self.device)
1274
  for outputs in self.stream_generate(**inputs, **gen_kwargs):
1275
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
 
1243
  for i, (old_query, response) in enumerate(history):
1244
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1245
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1246
+ inputs = tokenizer([prompt], return_tensors="pt")
1247
  inputs = inputs.to(self.device)
1248
  outputs = self.generate(**inputs, **gen_kwargs)
1249
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
 
1269
  for i, (old_query, response) in enumerate(history):
1270
  prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
1271
  prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
1272
+ inputs = tokenizer([prompt], return_tensors="pt")
1273
  inputs = inputs.to(self.device)
1274
  for outputs in self.stream_generate(**inputs, **gen_kwargs):
1275
  outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]