ShiJueXiaofei committed
Commit 2cdf703
1 Parent(s): 8fd7fba

fix: garbled inference output when use_cache = False


When the original model is loaded with use_cache = False, the input_ids truncation used for next-token prediction only checks is_first_forward, so it still truncates and feeds only the newest token as input_ids. Since no past_key_values are available in this case, the model's inference output is garbled.
Truncating to the newest predicted token should only happen when is_first_forward == False and self.config.use_cache == True; otherwise the original text sequence plus the tokens predicted so far must be passed to the model.
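A hypothetical reproduction sketch (the checkpoint path, prompt, and device handling are placeholders and not part of this commit): before this fix, disabling the KV cache makes generation lose its context after the first token.

from transformers import AutoModel, AutoTokenizer

model_path = "THUDM/chatglm3-6b"  # placeholder: any ChatGLM checkpoint shipping this modeling_chatglm.py
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()

model.config.use_cache = False  # the flag checked by prepare_inputs_for_generation after this fix

inputs = tokenizer("你好", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=64, use_cache=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
# Before the fix: every step after the first feeds only the newest token while no
# past_key_values exist, so the model loses all context and the output is garbled.
# After the fix: the full sequence is re-fed at every step, at the cost of
# recomputing attention over the whole prefix on each step.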

Files changed (1)
  1. modeling_chatglm.py +3 -2
modeling_chatglm.py CHANGED
@@ -904,8 +904,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         if position_ids is None:
             position_ids = self.get_position_ids(input_ids, device=input_ids.device)
         if not is_first_forward:
-            position_ids = position_ids[..., -1:]
-            input_ids = input_ids[:, -1:]
+            if self.config.use_cache:
+                position_ids = position_ids[..., -1:]
+                input_ids = input_ids[:, -1:]
         return {
             "input_ids": input_ids,
             "past_key_values": past_key_values,
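Read as plain Python, the patched branch behaves as sketched below (a simplified excerpt with comments added; is_first_forward, position_ids, and input_ids are local to prepare_inputs_for_generation):

if not is_first_forward:
    if self.config.use_cache:
        # Incremental decoding: past_key_values already holds the context for the
        # earlier tokens, so only the newest token and its position are passed.
        position_ids = position_ids[..., -1:]
        input_ids = input_ids[:, -1:]
    # When use_cache is False there is no cached context, so the full
    # input_ids / position_ids sequence must be re-fed on every step.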