import re
import json
import uuid
import warnings
from abc import ABC
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
    SystemMessage,
    AIMessage,
    BaseMessage,
    BaseMessageChunk,
    ToolCall,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from pydantic import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool

DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:

{tools}

You must always select one of the above tools and respond with only a JSON object matching the following schema:

{{
  "tool": <name of selected tool 1>,
  "tool_input": <parameters for selected tool 1, matching the tool's JSON schema>
}},
{{
  "tool": <name of selected tool 2>,
  "tool_input": <parameters for selected tool 2, matching the tool's JSON schema>
}}
"""  # noqa: E501


def extract_think(content):
    """Split model output into the <think>...</think> text and the text that follows it."""
    # Added by Cursor 20250726 jmd
    # Extract content within <think>...</think>
    think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
    think_text = think_match.group(1).strip() if think_match else ""
    # Extract text after </think>
    if think_match:
        post_think = content[think_match.end() :].lstrip()
    else:
        # Check if content starts with <think> but is missing the closing tag
        if content.strip().startswith("<think>"):
            # Extract everything after <think>
            think_start = content.find("<think>") + len("<think>")
            think_text = content[think_start:].strip()
            post_think = ""
        else:
            # No <think> found, so return the entire content as post_think
            post_think = content
    return think_text, post_think
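
# Illustrative behaviour (hypothetical input, for reference only):
#   extract_think('<think>Compare both cities.</think>{"tool": "GetWeather", "tool_input": {}}')
#   returns ('Compare both cities.', '{"tool": "GetWeather", "tool_input": {}}')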


class ToolCallingLLM(BaseChatModel, ABC):
    """ToolCallingLLM mixin to enable tool calling features on non-tool-calling models.

    Note: This is an incomplete mixin and should not be used directly. It must be used to extend an existing Chat Model.

    Setup:
        Install dependencies for your Chat Model.
        Any API keys or setup needed for your Chat Model are still applicable.

    Key init args — completion params:
        Refer to the documentation of the Chat Model you wish to extend with Tool Calling.

    Key init args — client params:
        Refer to the documentation of the Chat Model you wish to extend with Tool Calling.

    See full list of supported init args and their descriptions in the params section.

    Instantiate:
        ```
        # Example implementation using LiteLLM
        from langchain_community.chat_models import ChatLiteLLM

        class LiteLLMFunctions(ToolCallingLLM, ChatLiteLLM):

            def __init__(self, **kwargs: Any) -> None:
                super().__init__(**kwargs)

            @property
            def _llm_type(self) -> str:
                return "litellm_functions"

        llm = LiteLLMFunctions(model="ollama/phi3")
        ```

    Invoke:
        ```
        messages = [
            ("human", "What is the capital of France?")
        ]
        llm.invoke(messages)
        ```

        ```
        AIMessage(content='The capital of France is Paris.', id='run-497d0e1a-d63b-45e8-9c8b-5e76d99b9468-0')
        ```

    Tool calling:
        ```
        from pydantic import BaseModel, Field

        class GetWeather(BaseModel):
            '''Get the current weather in a given location'''

            location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

        class GetPopulation(BaseModel):
            '''Get the current population in a given location'''

            location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

        llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
        ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
        ai_msg.tool_calls
        ```

        ```
        [{'name': 'GetWeather', 'args': {'location': 'Austin, TX'}, 'id': 'call_25ed526917b94d8fa5db3fe30a8cf3c0'}]
        ```

    Response metadata:
        Refer to the documentation of the Chat Model you wish to extend with Tool Calling.

    """  # noqa: E501

    tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    def _generate_system_message_and_functions(
        self,
        kwargs: Dict[str, Any],
    ) -> Tuple[BaseMessage, List]:
        functions = kwargs.get("tools", [])
        # Convert functions to OpenAI tool schema
        functions = [convert_to_openai_tool(fn) for fn in functions]
        # Create system message with tool descriptions
        system_message_prompt_template = SystemMessagePromptTemplate.from_template(
            self.tool_system_prompt_template
        )
        system_message = system_message_prompt_template.format(
            tools=json.dumps(functions, indent=2)
        )
        return system_message, functions
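
    # For reference, convert_to_openai_tool() renders a tool definition such as the
    # GetWeather model from the class docstring into roughly this shape (abbreviated):
    #   {"type": "function",
    #    "function": {"name": "GetWeather",
    #                 "description": "Get the current weather in a given location",
    #                 "parameters": {...JSON schema for the arguments...}}}
    # which is why _process_response() looks up fn["function"]["name"] below.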

    def _process_response(
        self, response_message: BaseMessage, functions: List[Dict]
    ) -> AIMessage:
        if not isinstance(response_message.content, str):
            raise ValueError("ToolCallingLLM does not support non-string output.")
        # Extract <think>...</think> content and the text after </think> for further processing 20250726 jmd
        think_text, post_think = extract_think(response_message.content)
        ## For debugging
        # print("post_think")
        # print(post_think)
        # Remove backticks around code blocks
        post_think = re.sub(r"^```json", "", post_think)
        post_think = re.sub(r"^```", "", post_think)
        post_think = re.sub(r"```$", "", post_think)
        # Remove intervening backticks from adjacent code blocks
        post_think = re.sub(r"```\n```json", ",", post_think)
        # Remove trailing comma (if there is one)
        post_think = post_think.rstrip(",")
        # Parse output for JSON (support multiple objects separated by commas)
        try:
            # Works for one JSON object, or multiple JSON objects enclosed in "[]"
            parsed_json_results = json.loads(post_think)
            if not isinstance(parsed_json_results, list):
                parsed_json_results = [parsed_json_results]
        except json.JSONDecodeError:
            try:
                # Works for multiple JSON objects not enclosed in "[]"
                parsed_json_results = json.loads(f"[{post_think}]")
            except json.JSONDecodeError:
                # Return entire response if JSON wasn't parsed or is missing
                return AIMessage(content=response_message.content)
        # print("parsed_json_results")
        # print(parsed_json_results)
        tool_calls = []
        for parsed_json_result in parsed_json_results:
            # Get tool name from output
            called_tool_name = (
                parsed_json_result["tool"]
                if "tool" in parsed_json_result
                else (
                    parsed_json_result["name"] if "name" in parsed_json_result else None
                )
            )
            # Check if tool name is in functions list
            called_tool = next(
                (fn for fn in functions if fn["function"]["name"] == called_tool_name),
                None,
            )
            if called_tool is None:
                # Issue a warning and skip this tool call
                warnings.warn(f"Called tool ({called_tool_name}) not in functions list")
                continue
            # Get tool arguments from output
            called_tool_arguments = (
                parsed_json_result["tool_input"]
                if "tool_input" in parsed_json_result
                else (
                    parsed_json_result["parameters"]
                    if "parameters" in parsed_json_result
                    else {}
                )
            )
            tool_calls.append(
                ToolCall(
                    name=called_tool_name,
                    args=called_tool_arguments,
                    id=f"call_{str(uuid.uuid4()).replace('-', '')}",
                )
            )
        if not tool_calls:
            # If nothing valid was parsed, return the original content
            return AIMessage(content=response_message.content)
        # Put together the response message, preserving the <think> text as content
        response_message = AIMessage(
            content=f"<think>\n{think_text}\n</think>",
            tool_calls=tool_calls,
        )
        return response_message
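
    # Sketch of the message produced for a single parsed tool call (the id is a
    # randomly generated placeholder; tool name and args depend on the model output):
    #   AIMessage(
    #       content="<think>\n...\n</think>",
    #       tool_calls=[{"name": "GetWeather", "args": {"location": "Paris, France"}, "id": "call_<hex>"}],
    #   )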

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        system_message, functions = self._generate_system_message_and_functions(kwargs)
        response_message = super()._generate(  # type: ignore[safe-super]
            [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
        )
        response = self._process_response(
            response_message.generations[0].message, functions
        )
        return ChatResult(generations=[ChatGeneration(message=response)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        system_message, functions = self._generate_system_message_and_functions(kwargs)
        response_message = await super()._agenerate(
            [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
        )
        response = self._process_response(
            response_message.generations[0].message, functions
        )
        return ChatResult(generations=[ChatGeneration(message=response)])

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[BaseMessageChunk]:
        system_message, functions = self._generate_system_message_and_functions(kwargs)
        generation: Optional[BaseMessageChunk] = None
        # Accumulate the streamed chunks, then parse the combined output once
        async for chunk in super().astream(
            [system_message] + super()._convert_input(input).to_messages(),
            config=config,
            stop=stop,
            **kwargs,
        ):
            if generation is None:
                generation = chunk
            else:
                generation += chunk
        assert generation is not None
        response = self._process_response(generation, functions)
        yield cast(BaseMessageChunk, response)
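
    # Minimal usage sketch (assumes a concrete subclass such as the LiteLLMFunctions
    # example in the class docstring; shown as a comment so nothing runs at import):
    #
    #   llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
    #   async for message in llm_with_tools.astream("What is the weather in Paris?"):
    #       print(message.tool_calls)
    #
    # Because this override accumulates every chunk and parses the combined text,
    # it yields a single processed message rather than incremental tokens.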