import copy
import datetime
import json
import logging
import os
import random
import time
from pprint import pformat
from typing import Dict, Iterator, List, Literal, Optional, Union

import openai

# openai v0.x keeps error classes in openai.error; v1.x re-exports them at top level.
if openai.__version__.startswith('0.'):
    from openai.error import OpenAIError, RateLimitError  # noqa
else:
    from openai import OpenAIError, RateLimitError

from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.function_calling import BaseFnCallModel, simulate_response_completion_with_chat
from qwen_agent.llm.schema import ASSISTANT, Message
from qwen_agent.log import logger


def today_date():
    return datetime.date.today().strftime('%Y-%m-%d')


SYSTEM_PROMPT = """You are a deep research assistant. Your core function is to conduct thorough, multi-source investigations into any topic. You must handle both broad, open-domain inquiries and queries within specialized academic fields. For every request, synthesize information from credible, diverse sources to deliver a comprehensive, accurate, and objective response. When you have gathered sufficient information and are ready to provide the definitive response, you must enclose the entire final answer within <answer></answer> tags.
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "search", "description": "Perform Google web searches then returns a string of the top search results. Accepts multiple queries.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries."}}, "required": ["query"]}}}
{"type": "function", "function": {"name": "visit", "description": "Visit webpage(s) and return the summary of the content.", "parameters": {"type": "object", "properties": {"url": {"type": "array", "items": {"type": "string"}, "description": "The URL(s) of the webpage(s) to visit. Can be a single URL or an array of URLs."}, "goal": {"type": "string", "description": "The specific information goal for visiting webpage(s)."}}, "required": ["url", "goal"]}}}
{"type": "function", "function": {"name": "PythonInterpreter", "description": "Executes Python code in a sandboxed environment. To use this tool, you must follow this format:
1. The 'arguments' JSON object must be empty: {}.
2. The Python code to be executed must be placed immediately after the JSON block, enclosed within <code> and </code> tags.
IMPORTANT: Any output you want to see MUST be printed to standard output using the print() function.
Example of a correct call:
<tool_call>
{"name": "PythonInterpreter", "arguments": {}}
<code>
import numpy as np
# Your code here
print(f"The result is: {np.mean([1,2,3])}")
</code>
</tool_call>", "parameters": {"type": "object", "properties": {}, "required": []}}}
{"type": "function", "function": {"name": "google_scholar", "description": "Leverage Google Scholar to retrieve relevant information from academic publications. Accepts multiple queries. This tool will also return results from google search", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string", "description": "The search query."}, "minItems": 1, "description": "The list of search queries for Google Scholar."}}, "required": ["query"]}}}
{"type": "function", "function": {"name": "parse_file", "description": "This is a tool that can be used to parse multiple user uploaded local files such as PDF, DOCX, PPTX, TXT, CSV, XLSX, DOC, ZIP, MP4, MP3.", "parameters": {"type": "object", "properties": {"files": {"type": "array", "items": {"type": "string"}, "description": "The file name of the user uploaded local files to be parsed."}}, "required": ["files"]}}}
</tools>
For each function call, return a JSON object with the function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
"""
@register_llm('oai')
class TextChatAtOAI(BaseFnCallModel):

    def __init__(self, cfg: Optional[Dict] = None):
        super().__init__(cfg)
        self.model = self.model or 'gpt-4o-mini'
        cfg = cfg or {}

        # Accept any of the three config keys for the API endpoint.
        api_base = cfg.get('api_base')
        api_base = api_base or cfg.get('base_url')
        api_base = api_base or cfg.get('model_server')
        api_base = (api_base or '').strip()

        api_key = cfg.get('api_key')
        api_key = api_key or os.getenv('OPENAI_API_KEY')
        api_key = (api_key or 'EMPTY').strip()

        if openai.__version__.startswith('0.'):
            if api_base:
                openai.api_base = api_base
            if api_key:
                openai.api_key = api_key
            self._complete_create = openai.Completion.create
            self._chat_complete_create = openai.ChatCompletion.create
        else:
            api_kwargs = {}
            if api_base:
                api_kwargs['base_url'] = api_base
            if api_key:
                api_kwargs['api_key'] = api_key

            def _chat_complete_create(*args, **kwargs):
                # The OpenAI v1 API rejects these sampling args; they must be passed via extra_body.
                extra_params = ['top_k', 'repetition_penalty']
                if any((k in kwargs) for k in extra_params):
                    kwargs['extra_body'] = copy.deepcopy(kwargs.get('extra_body', {}))
                    for k in extra_params:
                        if k in kwargs:
                            kwargs['extra_body'][k] = kwargs.pop(k)
                if 'request_timeout' in kwargs:
                    kwargs['timeout'] = kwargs.pop('request_timeout')
                client = openai.OpenAI(**api_kwargs)
                return client.chat.completions.create(*args, **kwargs)

            def _complete_create(*args, **kwargs):
                # The OpenAI v1 API rejects these sampling args; they must be passed via extra_body.
                extra_params = ['top_k', 'repetition_penalty']
                if any((k in kwargs) for k in extra_params):
                    kwargs['extra_body'] = copy.deepcopy(kwargs.get('extra_body', {}))
                    for k in extra_params:
                        if k in kwargs:
                            kwargs['extra_body'][k] = kwargs.pop(k)
                if 'request_timeout' in kwargs:
                    kwargs['timeout'] = kwargs.pop('request_timeout')
                client = openai.OpenAI(**api_kwargs)
                return client.completions.create(*args, **kwargs)

            self._complete_create = _complete_create
            self._chat_complete_create = _chat_complete_create
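
    # Example construction (an illustrative sketch; the endpoint, key, and model
    # below are placeholders, not values shipped with this module):
    #
    #   llm = TextChatAtOAI({
    #       'model': 'gpt-4o-mini',
    #       'model_server': 'http://localhost:8000/v1',
    #       'api_key': 'EMPTY',
    #   })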

    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        messages = self.convert_messages_to_dicts(messages)
        try:
            MAX_RETRIES = 5
            INITIAL_DELAY = 2
            CONTENT_THRESHOLD = 50
            REASONING_THRESHOLD = 50
            response = None
            for attempt in range(MAX_RETRIES):
                try:
                    response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg)
                    break
                except RateLimitError as ex:
                    if attempt == MAX_RETRIES - 1:
                        logger.error(f"API rate limit error after {MAX_RETRIES} retries. Raising exception.")
                        raise ModelServiceError(exception=ex) from ex
                    # Exponential backoff with jitter: ~2-3s, ~4-5s, ~8-9s, ...
                    delay = INITIAL_DELAY * (2 ** attempt) + random.uniform(0, 1)
                    logger.warning(
                        f"Rate limit exceeded. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{MAX_RETRIES})"
                    )
                    time.sleep(delay)
                except OpenAIError as ex:
                    logger.error(f"An OpenAI error occurred: {ex}")
                    raise ModelServiceError(exception=ex) from ex
            if delta_stream:
                for chunk in response:
                    if chunk.choices:
                        choice = chunk.choices[0]
                        if hasattr(choice.delta, 'reasoning_content') and choice.delta.reasoning_content:
                            yield [
                                Message(
                                    role=ASSISTANT,
                                    content='',
                                    reasoning_content=choice.delta.reasoning_content,
                                )
                            ]
                        if hasattr(choice.delta, 'content') and choice.delta.content:
                            yield [Message(role=ASSISTANT, content=choice.delta.content, reasoning_content='')]
                        if hasattr(choice.delta, 'tool_calls') and choice.delta.tool_calls:
                            function_name = choice.delta.tool_calls[0].function.name
                            function_call = {
                                'name': function_name,
                                'arguments': json.loads(choice.delta.tool_calls[0].function.arguments),
                            }
                            function_json = json.dumps(function_call, ensure_ascii=False)
                            yield [Message(role=ASSISTANT, content=f'<tool_call>{function_json}</tool_call>')]
                    logger.info(f'delta_stream message chunk: {chunk}')
            else:
                full_response = ''
                full_reasoning_content = ''
                content_buffer = ''
                reasoning_content_buffer = ''
                for chunk in response:
                    if not chunk.choices:
                        continue
                    choice = chunk.choices[0]
                    new_content = choice.delta.content if hasattr(choice.delta, 'content') and choice.delta.content else ''
                    # Some servers stream thinking tokens as `reasoning_content` (e.g. vLLM)
                    # and others as `reasoning`; accept either.
                    new_reasoning = getattr(choice.delta, 'reasoning_content', None) or getattr(choice.delta, 'reasoning', None) or ''
                    has_tool_calls = hasattr(choice.delta, 'tool_calls') and choice.delta.tool_calls
                    if new_reasoning:
                        full_reasoning_content += new_reasoning
                        reasoning_content_buffer += new_reasoning
                    if new_content:
                        full_response += new_content
                        content_buffer += new_content
                    if has_tool_calls:
                        function_name = choice.delta.tool_calls[0].function.name
                        function_call = {
                            'name': function_name,
                            'arguments': json.loads(choice.delta.tool_calls[0].function.arguments),
                        }
                        function_json = json.dumps(function_call, ensure_ascii=False)
                        logger.info(json.dumps(function_call, ensure_ascii=False, indent=4))
                        full_response += f'<tool_call>{function_json}</tool_call>'
                        content_buffer += '<tool_call>'
                    # Flush once either buffer is large enough or a newline arrives.
                    if (len(content_buffer) >= CONTENT_THRESHOLD or
                            len(reasoning_content_buffer) >= REASONING_THRESHOLD or
                            '\n' in new_content or
                            '\n' in new_reasoning):
                        yield [Message(role=ASSISTANT, content=full_response, reasoning_content=full_reasoning_content)]
                        content_buffer = ''
                        reasoning_content_buffer = ''
                    logger.info(f'message chunk: {chunk}')
                if content_buffer or reasoning_content_buffer:
                    yield [Message(role=ASSISTANT, content=full_response, reasoning_content=full_reasoning_content)]
        except OpenAIError as ex:
            raise ModelServiceError(exception=ex) from ex
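
    # Consuming the non-delta stream (an illustrative sketch; `llm` as constructed
    # in the example above). Each yielded list carries one cumulative assistant
    # message, so only the last yield is needed for the full text:
    #
    #   *_, final = llm._chat_stream(msgs, delta_stream=False, generate_cfg={})
    #   print(final[-1].content)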

    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        messages = self.convert_messages_to_dicts(messages)
        try:
            response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg)
            if hasattr(response.choices[0].message, 'reasoning_content'):
                return [
                    Message(role=ASSISTANT,
                            content=response.choices[0].message.content,
                            reasoning_content=response.choices[0].message.reasoning_content)
                ]
            else:
                return [Message(role=ASSISTANT, content=response.choices[0].message.content)]
        except OpenAIError as ex:
            raise ModelServiceError(exception=ex) from ex

    def _chat_with_functions(
        self,
        messages: List[Message],
        functions: List[Dict],
        stream: bool,
        delta_stream: bool,
        generate_cfg: dict,
        lang: Literal['en', 'zh'],
    ) -> Union[List[Message], Iterator[List[Message]]]:
        # if delta_stream:
        #     raise NotImplementedError('Please use stream=True with delta_stream=False, because delta_stream=True'
        #                               ' is not implemented for function calling due to some technical reasons.')
        generate_cfg = copy.deepcopy(generate_cfg)
        # These options are handled via prompt simulation, not the OpenAI API.
        for k in ['parallel_function_calls', 'function_choice', 'thought_in_content']:
            if k in generate_cfg:
                del generate_cfg[k]
        messages = simulate_response_completion_with_chat(messages)
        return self._chat(messages, stream=stream, delta_stream=delta_stream, generate_cfg=generate_cfg)

    def _chat(
        self,
        messages: List[Union[Message, Dict]],
        stream: bool,
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Union[List[Message], Iterator[List[Message]]]:
        if stream:
            return self._chat_stream(messages, delta_stream=delta_stream, generate_cfg=generate_cfg)
        else:
            return self._chat_no_stream(messages, generate_cfg=generate_cfg)

    @staticmethod
    def convert_messages_to_dicts(messages: List[Message]) -> List[dict]:
        # TODO: Revisit once vLLM-deployed models need reasoning content passed back.
        # For now it is dropped to stay compatible with older vLLM versions, and the
        # reasoning content is not otherwise used downstream.
        messages = [msg.model_dump() for msg in messages]
        return_messages = []
        # The first message is assumed to be the system message; its content is
        # replaced with the deep-research prompt plus the current date.
        messages[0]['content'] = SYSTEM_PROMPT + 'Current date: ' + today_date()
        for msg in messages:
            msg['content'] = msg['content'].replace('<think>\n<think>\n', '<think>\n\n')
            return_messages.append(msg)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
        return return_messages
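

# A minimal smoke test (a sketch, not part of the original module). The endpoint,
# API key, and model name are placeholders for a locally served OpenAI-compatible
# model; the internal _chat_no_stream helper is called directly for brevity.
if __name__ == '__main__':
    llm = TextChatAtOAI({
        'model': 'gpt-4o-mini',
        'model_server': 'http://localhost:8000/v1',  # placeholder endpoint
        'api_key': 'EMPTY',
    })
    out = llm._chat_no_stream(
        [Message(role='system', content=''),  # overwritten with SYSTEM_PROMPT internally
         Message(role='user', content='What is 2 + 2?')],
        generate_cfg={},
    )
    print(out[-1].content)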