changwangss committed on
Commit
d223af7
1 Parent(s): e580bc8

update modeling_baichuan.py for torchscript mode with past_kv

Browse files

This enables model inference to fall back to `use_cache` and `return_dict` values from `model.config` when they are not passed explicitly.

Files changed (1) hide show
  1. modeling_baichuan.py +4 -2
modeling_baichuan.py CHANGED
@@ -287,7 +287,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
287
  use_cache: Optional[bool] = False,
288
  output_attentions: Optional[bool] = False,
289
  output_hidden_states: Optional[bool] = False,
290
- return_dict: Optional[bool] = True,
291
  ) -> Union[Tuple, BaseModelOutputWithPast]:
292
 
293
  if input_ids is not None and inputs_embeds is not None:
@@ -299,6 +299,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
299
  else:
300
  raise ValueError("You need to provide input_ids or inputs_embeds")
301
 
 
 
302
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
303
 
304
  seq_length_with_past = seq_length
@@ -439,7 +441,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
439
  use_cache: Optional[bool] = None,
440
  output_attentions: Optional[bool] = False,
441
  output_hidden_states: Optional[bool] = False,
442
- return_dict: Optional[bool] = True,
443
  **kwargs
444
  ) -> Union[Tuple, CausalLMOutputWithPast]:
445
 
 
287
  use_cache: Optional[bool] = False,
288
  output_attentions: Optional[bool] = False,
289
  output_hidden_states: Optional[bool] = False,
290
+ return_dict: Optional[bool] = None,
291
  ) -> Union[Tuple, BaseModelOutputWithPast]:
292
 
293
  if input_ids is not None and inputs_embeds is not None:
 
299
  else:
300
  raise ValueError("You need to provide input_ids or inputs_embeds")
301
 
302
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
303
+
304
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
305
 
306
  seq_length_with_past = seq_length
 
441
  use_cache: Optional[bool] = None,
442
  output_attentions: Optional[bool] = False,
443
  output_hidden_states: Optional[bool] = False,
444
+ return_dict: Optional[bool] = None,
445
  **kwargs
446
  ) -> Union[Tuple, CausalLMOutputWithPast]:
447