Spaces:
Runtime error
Runtime error
langchain-qa-bot
/
docs
/langchain
/libs
/partners
/anthropic
/langchain_anthropic
/output_parsers.py
from typing import Any, List, Optional, Type | |
from langchain_core.messages import ToolCall | |
from langchain_core.output_parsers import BaseGenerationOutputParser | |
from langchain_core.outputs import ChatGeneration, Generation | |
from langchain_core.pydantic_v1 import BaseModel | |
class ToolsOutputParser(BaseGenerationOutputParser):
    """Turn the ``tool_use`` content blocks of a chat generation into tool calls.

    Only the first generation in the result list is inspected. Each tool call
    is a dict with ``name``, ``args``, ``id`` and ``index`` (the position of
    the originating block inside the message's content list), optionally
    post-processed according to the fields below.
    """

    # When True, return only the first tool call (or None when there is none)
    # instead of a list.
    first_tool_only: bool = False
    # When True, and no pydantic_schemas are set, return only each call's "args".
    args_only: bool = False
    # When set, each tool call is instantiated as the schema whose class name
    # equals the tool call's "name". Takes precedence over args_only.
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None

    class Config:
        extra = "forbid"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
        """Parse a list of candidate model Generations into a specific format.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input. Only
                the first one is used.
            partial: Accepted for interface compatibility; not used here.

        Returns:
            Structured output: a list of tool calls (dicts, ``args`` values or
            schema instances depending on configuration), or a single such
            item / ``None`` when ``first_tool_only`` is set. A missing or
            non-chat generation, or a message whose content is a plain string,
            yields no tool calls.

        Raises:
            KeyError: If ``pydantic_schemas`` is set and a tool call's name does
                not match any schema class name.
        """
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []

        content = result[0].message.content
        tool_calls: List = []
        if not isinstance(content, str):
            # Content lists may mix plain strings with dict blocks; subscripting
            # a string with "type" would raise TypeError, so keep dict blocks
            # only. Keys are positions in the *original* content list.
            tool_blocks = {
                i: block
                for i, block in enumerate(content)
                if isinstance(block, dict) and block["type"] == "tool_use"
            }
            # Map tool call id to index
            id_to_index = {block["id"]: i for i, block in tool_blocks.items()}
            tool_calls = [
                {**tc, "index": id_to_index[tc["id"]]}
                for tc in extract_tool_calls(list(tool_blocks.values()))
            ]

        if self.pydantic_schemas:
            tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
        elif self.args_only:
            tool_calls = [tc["args"] for tc in tool_calls]

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        return tool_calls

    def _pydantic_parse(self, tool_call: dict) -> BaseModel:
        """Instantiate the schema named after ``tool_call["name"]`` from its args.

        Args:
            tool_call: A tool call dict providing at least ``name`` and ``args``.

        Returns:
            An instance of the matching schema built from ``tool_call["args"]``.

        Raises:
            KeyError: If no configured schema has a class name equal to the
                tool call's name.
        """
        schemas_by_name = {
            schema.__name__: schema for schema in self.pydantic_schemas or []
        }
        cls_ = schemas_by_name[tool_call["name"]]
        return cls_(**tool_call["args"])
def extract_tool_calls(content: List[Any]) -> List[ToolCall]:
    """Collect a ``ToolCall`` for every ``tool_use`` block in ``content``.

    Args:
        content: Message content blocks. Elements that are not dicts (for
            example plain strings) are skipped rather than subscripted.

    Returns:
        One ``ToolCall`` per ``tool_use`` block, in order of appearance, built
        from the block's ``name``, ``input`` (as ``args``) and ``id``. Empty
        when there are no such blocks.

    Raises:
        KeyError: If a dict block has no ``type`` key, or a ``tool_use`` block
            lacks ``name``, ``input`` or ``id``.
    """
    return [
        ToolCall(name=block["name"], args=block["input"], id=block["id"])
        for block in content
        # Only dict blocks can be tool_use blocks; subscripting a plain string
        # with "type" would raise TypeError.
        if isinstance(block, dict) and block["type"] == "tool_use"
    ]