from typing import Any, List, Optional, Type

from langchain_core.messages import ToolCall
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel


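class ToolsOutputParser(BaseGenerationOutputParser):
    """Parse the ``tool_use`` content blocks of a chat message into tool calls."""

    # Return only the first tool call (or None) instead of a list.
    first_tool_only: bool = False
    # Return only each tool call's args dict. Ignored if pydantic_schemas is set.
    args_only: bool = False
    # Parse each tool call into the schema whose class name matches the tool name.
    pydantic_schemas: Optional[List[Type[BaseModel]]] = None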

    class Config:
        extra = "forbid"

    def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
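        """Parse the first candidate Generation into tool calls.

        Args:
            result: A list of Generations to be parsed. The Generations are assumed
                to be different candidate outputs for a single model input. Only
                the first one is inspected.
            partial: Whether the output is partial (e.g. streamed). Not used
                by this parser.

        Returns:
            A list of tool calls, or a single tool call (``None`` if there is
            none) when ``first_tool_only`` is set. Each tool call is a pydantic
            object if ``pydantic_schemas`` is set, its args dict if ``args_only``
            is set, and otherwise a dict with ``name``, ``args``, ``id`` and
            ``index`` keys.
        """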
        if not result or not isinstance(result[0], ChatGeneration):
            return None if self.first_tool_only else []
        message = result[0].message
        if isinstance(message.content, str):
            # Plain-text responses carry no tool_use blocks.
            tool_calls: List = []
        else:
            content: List = message.content
            _tool_calls = [dict(tc) for tc in extract_tool_calls(content)]
            # Map each tool call id to its block's position in `content`.
            id_to_index = {
                block["id"]: i
                for i, block in enumerate(content)
                if isinstance(block, dict) and block.get("type") == "tool_use"
            }
            tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
        if self.pydantic_schemas:
            tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
        elif self.args_only:
            tool_calls = [tc["args"] for tc in tool_calls]

        if self.first_tool_only:
            return tool_calls[0] if tool_calls else None
        return tool_calls

    def _pydantic_parse(self, tool_call: dict) -> BaseModel:
        """Instantiate the schema whose class name matches the tool call's name."""
        cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
            tool_call["name"]
        ]
        return cls_(**tool_call["args"])


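def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
    """Collect a ToolCall for every ``tool_use`` block in ``content``."""
    tool_calls = []
    for block in content:
        # Skip plain-string items and non-tool blocks such as text.
        if not isinstance(block, dict) or block.get("type") != "tool_use":
            continue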
        tool_calls.append(
            ToolCall(name=block["name"], args=block["input"], id=block["id"])
        )
    return tool_calls