changwangss committed
Commit d223af7 • Parent: e580bc8

update modeling_baichuan.py for torchscript mode with past_kv

Enable model inference with use_cache and return_dict taken from model.config.

Files changed:
- modeling_baichuan.py +4 -2
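For context, a minimal usage sketch of what this config fallback enables. It is illustrative only: the checkpoint id is a placeholder for whichever Baichuan checkpoint ships the patched modeling_baichuan.py, and loading assumes trust_remote_code=True.

# Sketch (placeholder checkpoint): with return_dict defaulting to None and
# use_cache falling back to the config, both switches can be driven from
# model.config alone, which is what a TorchScript export path needs
# (tuple outputs, KV cache enabled).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/baichuan-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.eval()

model.config.use_cache = True     # honored now that forward falls back to the config
model.config.return_dict = False  # forward returns a plain tuple, as jit.trace requires

inputs = tokenizer("Hello", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# outputs should be a tuple: logits first, past_key_values second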
modeling_baichuan.py  CHANGED

@@ -287,7 +287,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
-        return_dict: Optional[bool] = True,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:

         if input_ids is not None and inputs_embeds is not None:
@@ -299,6 +299,8 @@ class BaichuanModel(BaichuanPreTrainedModel):
         else:
             raise ValueError("You need to provide input_ids or inputs_embeds")

+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         seq_length_with_past = seq_length
@@ -439,7 +441,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
-        return_dict: Optional[bool] = True,
+        return_dict: Optional[bool] = None,
         **kwargs
     ) -> Union[Tuple, CausalLMOutputWithPast]:
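A hedged sketch of the "torchscript mode with past_kv" path named in the commit title, continuing from the snippet above. The example_kwarg_inputs argument needs PyTorch 2.x, and downstream export tooling may wrap the model differently, so treat this as an outline rather than the exact recipe.

# Run one eager prefill step to obtain past_key_values, then trace the decode
# step that consumes it; strict=False lets the tracer accept the nested tuple cache.
with torch.no_grad():
    logits, past_key_values = model(**inputs)[:2]
    next_token = logits[:, -1:].argmax(dim=-1)
    attention_mask = torch.cat(
        [inputs["attention_mask"], torch.ones_like(next_token)], dim=-1
    )
    traced = torch.jit.trace(
        model,
        example_kwarg_inputs={
            "input_ids": next_token,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
        },
        strict=False,
    )
traced.save("baichuan_decode_step.pt")  # hypothetical output path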