changwangss committed on
Commit
b7017d4
1 Parent(s): 4224a07

update modeling_baichuan.py for torchscript mode with past_kv

Browse files

Enables model inference to fall back to `use_cache` and `use_return_dict` from `model.config` when those arguments are passed as `None`.

Files changed (1) hide show
  1. modeling_baichuan.py +4 -2
modeling_baichuan.py CHANGED
@@ -365,7 +365,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
365
  use_cache: Optional[bool] = False,
366
  output_attentions: Optional[bool] = False,
367
  output_hidden_states: Optional[bool] = False,
368
- return_dict: Optional[bool] = True,
369
  ) -> Union[Tuple, BaseModelOutputWithPast]:
370
  if input_ids is not None and inputs_embeds is not None:
371
  raise ValueError(
@@ -378,6 +378,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
378
  else:
379
  raise ValueError("You need to provide input_ids or inputs_embeds")
380
 
 
 
381
  return_dict = (
382
  return_dict if return_dict is not None else self.config.use_return_dict
383
  )
@@ -682,7 +684,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
682
  use_cache: Optional[bool] = None,
683
  output_attentions: Optional[bool] = False,
684
  output_hidden_states: Optional[bool] = False,
685
- return_dict: Optional[bool] = True,
686
  **kwargs,
687
  ) -> Union[Tuple, CausalLMOutputWithPast]:
688
  return_dict = (
 
365
  use_cache: Optional[bool] = False,
366
  output_attentions: Optional[bool] = False,
367
  output_hidden_states: Optional[bool] = False,
368
+ return_dict: Optional[bool] = None,
369
  ) -> Union[Tuple, BaseModelOutputWithPast]:
370
  if input_ids is not None and inputs_embeds is not None:
371
  raise ValueError(
 
378
  else:
379
  raise ValueError("You need to provide input_ids or inputs_embeds")
380
 
381
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
382
+
383
  return_dict = (
384
  return_dict if return_dict is not None else self.config.use_return_dict
385
  )
 
684
  use_cache: Optional[bool] = None,
685
  output_attentions: Optional[bool] = False,
686
  output_hidden_states: Optional[bool] = False,
687
+ return_dict: Optional[bool] = None,
688
  **kwargs,
689
  ) -> Union[Tuple, CausalLMOutputWithPast]:
690
  return_dict = (