zxdu20 committed
Commit c3dece3
1 Parent(s): 9d1509a

Add logit processor for NaN or Inf scores

Files changed (1)
  1. modeling_chatglm.py +17 -3
modeling_chatglm.py CHANGED
```diff
@@ -3,6 +3,7 @@
 import math
 import copy
 import os
+import time
 
 import torch
 import torch.utils.checkpoint
@@ -23,8 +24,10 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
 )
 from transformers.modeling_utils import PreTrainedModel
-
 from transformers.utils import logging
+from transformers.generation.logits_process import LogitsProcessor
+from transformers.generation.utils import LogitsProcessorList
+
 from .configuration_chatglm import ChatGLMConfig
 
 # flags required to enable jit fusion kernels
@@ -44,6 +47,14 @@ CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
 ]
 
 
+class InvalidScoreLogitsProcessor(LogitsProcessor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if torch.isnan(scores).any() or torch.isinf(scores).any():
+            scores.zero_()
+            scores[..., 20005] = 1e5
+        return scores
+
+
 def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
     """Load tf checkpoints in a pytorch model."""
     try:
@@ -1078,11 +1089,14 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
     @torch.no_grad()
     def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
-             do_sample=True, top_p=0.7, temperature=0.95, **kwargs):
+             do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
         if history is None:
             history = []
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
         gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
-                      "temperature": temperature, **kwargs}
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
         if not history:
             prompt = query
         else:
```
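The heart of the change is `InvalidScoreLogitsProcessor`, which follows the standard `transformers` `LogitsProcessor` contract: `__call__` receives the running `input_ids` and the next-token `scores`, and returns the scores that generation then samples from. When any score is NaN or Inf, the processor zeroes the whole distribution and puts a large spike on token id 20005, a model-specific fallback token, so sampling can no longer crash on an invalid distribution. Below is a minimal sketch of that behavior in isolation; the vocab size and tensor values are made up for illustration (just large enough to contain the fallback id, not the real ChatGLM-6B vocab size):

```python
import torch
from transformers.generation.logits_process import LogitsProcessor

class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # If the model produced any NaN/Inf logit, the distribution is
        # unusable: wipe it and force all mass onto one fallback token.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 20005] = 1e5
        return scores

# Toy demonstration: batch of one, vocab just large enough to contain
# the fallback id (illustrative, not the real vocab size).
proc = InvalidScoreLogitsProcessor()
scores = torch.full((1, 20006), 0.1)
scores[0, 0] = float("nan")            # simulate a numerical blow-up
fixed = proc(torch.zeros((1, 4), dtype=torch.long), scores)
print(fixed.argmax(dim=-1))            # tensor([20005])
```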
 
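On the API side, `chat()` gains a `logits_processor` parameter and always appends the guard, so existing callers are protected without any code change, while advanced callers can stack their own processors alongside it. A hedged usage sketch, assuming the model loads as in the repository README (the repo id and prompt are illustrative):

```python
from transformers import AutoModel, AutoTokenizer
from transformers.generation.utils import LogitsProcessorList

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# Default path: chat() builds its own LogitsProcessorList and appends
# InvalidScoreLogitsProcessor, so NaN/Inf logits cannot crash sampling.
response, history = model.chat(tokenizer, "Hello", history=[])

# Custom path: pass a list of your own processors; chat() appends the
# NaN/Inf guard to whatever list is supplied.
custom = LogitsProcessorList()
response, history = model.chat(tokenizer, "Hello", history=history,
                               logits_processor=custom)
print(response)
```

One design note: because the guard is appended unconditionally, it runs after any caller-supplied processors in the list, so it also catches NaN/Inf values those processors might themselves introduce.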