deprecate argument stream in model.chat()
modeling_qwen.py  +23 -36
@@ -60,6 +60,12 @@ If you are directly using the model downloaded from Huggingface, please make sur
 如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
 """
 
+_SENTINEL = object()
+_ERROR_STREAM_IN_CHAT = """\
+Passing the argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
+向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
+"""
+
 apply_rotary_emb_func = None
 rms_norm = None
 flash_attn_unpadded_func = None
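The point of `_SENTINEL = object()` is that an `object()` instance is identical only to itself, so an identity check against the default can distinguish "argument omitted" from every real value a caller might pass, including `stream=False` and `stream=None`, which an ordinary `False`/`None` default could not. A minimal standalone sketch of the same pattern (the names `_MISSING` and `greet` are illustrative, not part of modeling_qwen.py):

_MISSING = object()  # unique marker; no caller-supplied value is ever identical to it

def greet(name, shout=_MISSING):
    # `is` compares identity, so this triggers for shout=True, shout=False,
    # and even shout=None -- any explicit use of the retired argument.
    if shout is not _MISSING:
        raise TypeError("`shout` is deprecated; use greet_loud() instead")
    return f"hello, {name}"

greet("world")              # fine: argument omitted
greet("world", shout=True)  # TypeError with the deprecation message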
@@ -977,10 +983,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         history: Optional[HistoryType],
         system: str = "You are a helpful assistant.",
         append_history: bool = True,
-        stream: Optional[bool] = False,
+        stream: Optional[bool] = _SENTINEL,
         stop_words_ids: Optional[List[List[int]]] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
+        assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
         assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
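After this change, any call that still passes `stream` fails at the very top of chat(), before tokenization or generation, with the bilingual message defined above. A hedged usage sketch, assuming the usual trust_remote_code loading pattern for this repository and the "Qwen/Qwen-7B-Chat" checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True).eval()

# Still supported: blocking call that returns the full response and updated history.
response, history = model.chat(tokenizer, "Hello!", history=None)

# Rejected from now on: trips `assert stream is _SENTINEL` immediately.
# model.chat(tokenizer, "Hello!", history=None, stream=True)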
@@ -1000,41 +1007,21 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             self.generation_config.chat_format, tokenizer
         ))
         input_ids = torch.tensor([context_tokens]).to(self.device)
-        if stream:
-            ...
-                    break
-                yield tokenizer.decode(outputs, skip_special_tokens=True)
-
-            return stream_generator()
-        else:
-            outputs = self.generate(
-                input_ids,
-                stop_words_ids = stop_words_ids,
-                return_dict_in_generate = False,
-                **kwargs,
-            )
-
-            response = decode_tokens(
-                outputs[0],
-                tokenizer,
-                raw_text_len=len(raw_text),
-                context_length=len(context_tokens),
-                chat_format=self.generation_config.chat_format,
-                verbose=False,
-            )
+        outputs = self.generate(
+            input_ids,
+            stop_words_ids = stop_words_ids,
+            return_dict_in_generate = False,
+            **kwargs,
+        )
+
+        response = decode_tokens(
+            outputs[0],
+            tokenizer,
+            raw_text_len=len(raw_text),
+            context_length=len(context_tokens),
+            chat_format=self.generation_config.chat_format,
+            verbose=False,
+        )
 
         if append_history:
             history.append((query, response))
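With the streaming branch removed, chat() is a single non-streaming path; streaming callers are expected to migrate to model.chat_stream(), which, judging from the removed `yield tokenizer.decode(...)` line and the error text, yields the decoded response accumulated so far, one token at a time. A migration sketch under that assumption (the exact chat_stream signature is not shown in this diff):

# Before (deprecated, removed above):
#     for part in model.chat(tokenizer, query, history=None, stream=True): ...
# After:
history = []
for partial_response in model.chat_stream(tokenizer, "Tell me a joke.", history=history):
    print(partial_response)  # decoded response up to the latest generated token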