File size: 4,561 Bytes
d519be4
cb3dcae
d519be4
cb3dcae
 
d519be4
 
cb3dcae
d519be4
 
cb3dcae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d519be4
 
 
 
 
 
3f5217a
 
 
d519be4
cb3dcae
 
 
 
 
 
 
 
 
aa6ef3d
4701923
d519be4
aa6ef3d
d519be4
cb34a9e
d519be4
cb3dcae
 
 
d519be4
aa6ef3d
d519be4
 
cb3dcae
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

from abc import ABC, abstractmethod
from typing import Dict, Union, get_origin, get_args
from pydantic import BaseModel, Field
from types import UnionType
import os
import logging

from src.vectorstore import VectorStore
# from langchain.tools import tool


class ToolBase(BaseModel, ABC):
    """Base class for tools exposed to an OpenAI-style function-calling model.

    A subclass declares its arguments as pydantic fields, describes itself in
    its class docstring, and implements :meth:`invoke`.
    """

    @classmethod
    @abstractmethod
    def invoke(cls, input: Dict):
        """Run the tool.

        Args:
            input: Mapping of argument names to the values supplied by the caller.
        """

    @classmethod
    def to_openai_tool(cls):
        """
        Build the OpenAI function-calling descriptor for this tool class.

        The function name is the class name (unchanged), the description is the
        class docstring (empty string when there is none), and every pydantic
        field becomes one entry under ``parameters.properties``.

        Field handling:
          * ``int`` -> "integer", ``bool`` -> "boolean", ``float`` -> "number",
            anything else -> "string".
          * ``X | None``, ``Optional[X]`` and ``Union[X, None]`` are all treated
            as a nullable ``X``.
          * A field whose default is a list is treated as an enum: the list
            becomes the ``enum`` of the property, and the field is listed as
            required unless its annotation allows ``None``.

        Returns:
            dict: ``{"type": "function", "function": {...}}``.

        Raises:
            TypeError: If a field is annotated as a union of more than one
                non-None type, or as a union containing only None.
        """
        properties = {}
        required = []

        for field_name, field_info in cls.model_fields.items():
            annotation = field_info.annotation

            # Both spellings of a union are unwrapped the same way, so that
            # Optional[X] and X | None cannot produce different schemas.
            is_nullable = False
            if get_origin(annotation) in (Union, UnionType):
                members = get_args(annotation)
                is_nullable = type(None) in members
                concrete = [member for member in members if member is not type(None)]
                if len(concrete) > 1:
                    raise TypeError("It can be union of only a valid type (str, int, bool, etc) and None")
                if not concrete:
                    raise TypeError("There must be a valid type (str, int, bool, etc) not only None")
                annotation = concrete[0]

            # Map the Python type onto a JSON Schema type; default to string.
            if annotation is int:
                field_type = "integer"
            elif annotation is bool:
                field_type = "boolean"
            elif annotation is float:
                field_type = "number"
            else:
                field_type = "string"

            schema = {
                "type": field_type,
                "description": field_info.description,
            }
            properties[field_name] = schema

            already_required = field_info.is_required()
            if already_required:
                required.append(field_name)

            # A list default is this file's convention for "allowed values".
            default = getattr(field_info, "default", None)
            if isinstance(default, list):
                schema["enum"] = default
                # Enum fields must be supplied by the caller unless None is
                # allowed; never list the same field twice.
                if not (already_required or is_nullable):
                    required.append(field_name)

        return {
            "type": "function",
            "function": {
                "name": cls.__name__,
                # A tool class without a docstring has __doc__ set to None.
                "description": (cls.__doc__ or "").strip(),
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

# Registry of tool classes, keyed by the function name each one exposes
# (filled in by tool_register).
tools: Dict[str, type[ToolBase]] = {}
# OpenAI-format tool descriptors, appended in registration order.
oitools: list = []


# Vector store shared by the tools in this module; every setting comes from
# an environment variable.
vector_store = VectorStore(
    embeddings_model=os.getenv("EMBEDDINGS_MODEL"),  # e.g. "BAAI/bge-m3"
    vs_local_path=os.getenv("VS_LOCAL_PATH"),
    vs_hf_path=os.getenv("VS_HF_PATH"),
    # Falls back to 3 contexts when RETRIEVE_NUM_CONTEXTS is not set.
    number_of_contexts=int(os.getenv("RETRIEVE_NUM_CONTEXTS", "3")),
)


def tool_register(cls: BaseModel):
    """Class decorator that registers a tool class.

    The class's OpenAI descriptor is appended to ``oitools`` and the class
    itself is stored in ``tools`` under the descriptor's function name.

    Args:
        cls: The tool class being decorated; it must provide ``to_openai_tool()``.

    Returns:
        The same class, unchanged, so the decorated name stays bound to the
        class instead of being replaced by ``None``.
    """
    oaitool = cls.to_openai_tool()

    oitools.append(oaitool)
    tools[oaitool["function"]["name"]] = cls

    return cls


@tool_register
class retrieve_wiki_data(ToolBase):
    """Retrieves relevant information from wikipedia, based on the user's query."""
    # NOTE: the docstring above is read at runtime (to_openai_tool uses it as
    # the tool description), so it is program text, not just documentation.

    # Emitted once, while the class body executes at import time.
    logging.info("@tool_register: retrieve_wiki_data()")

    query: str = Field(description="The user's input or question, used to search Wikipedia.")
    # Here `query` is the object returned by Field(...), not a user-supplied string.
    logging.info(f"query: {query}")

    @classmethod
    def invoke(cls, input: Dict) -> str:
        """Look up context for the query carried in ``input``.

        Args:
            input: Tool arguments; the search text is read from the "query" key.

        Returns:
            Whatever ``vector_store.get_context`` returns for the query, or a
            fixed error message when "query" is absent or empty.
        """
        logging.info(f"retrieve_wiki_data.invoke() input: {input}")

        search_text = input.get("query")
        if search_text:
            return vector_store.get_context(search_text)

        # Maintenance fallback kept for reference:
        # return "We are currently working on it. You can't use this tool right now—please try again later. Thank you for your patience!"
        return "Missing required argument: query."