File size: 3,793 Bytes
42cd5f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import tempfile
import warnings
from typing import Optional

import typer
from rich import print
from typing_extensions import Annotated, List

from rag.agents.interface import get_pipeline


# Turn off the Hugging Face `tokenizers` library's internal parallelism.
# Forking a process after that thread pool has been used can deadlock or
# behave unpredictably (notably under multiprocessing). Disabling it trades
# some tokenization speed for safe, predictable execution.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Silence every DeprecationWarning and UserWarning process-wide (no module or
# message filter is given, so the filters apply globally).
for _ignored_category in (DeprecationWarning, UserWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category


def run(inputs: Annotated[str, typer.Argument(help="The list of fields to fetch")],
        types: Annotated[Optional[str], typer.Argument(help="The list of types of the fields")] = None,
        keywords: Annotated[Optional[str], typer.Argument(help="The list of table column keywords")] = None,
        file_path: Annotated[Optional[str], typer.Option(help="The file to process")] = None,
        agent: Annotated[str, typer.Option(help="Selected agent")] = "llamaindex",
        index_name: Annotated[Optional[str], typer.Option(help="Index to identify embeddings")] = None,
        options: Annotated[Optional[List[str]], typer.Option(help="Options to pass to the agent")] = None,
        group_by_rows: Annotated[bool, typer.Option(help="Group JSON collection by rows")] = True,
        update_targets: Annotated[bool, typer.Option(help="Update targets")] = True,
        debug: Annotated[bool, typer.Option(help="Enable debug mode")] = False):
    """Fetch the requested fields with the selected agent and print the JSON answer.

    INPUTS is a comma-separated list of field names. When TYPES (a
    comma-separated list of field types) is given, the agent is asked to
    "retrieve" those fields; without TYPES, INPUTS is sent verbatim as a
    free-form query. KEYWORDS, if given, is a comma-separated list of table
    column keywords.
    """
    # NOTE: this docstring doubles as the Typer `--help` text, so keep it
    # user-facing; implementation notes belong in comments.
    if types:
        # Structured mode: split both lists and prefix the query.
        query = 'retrieve ' + inputs
        query_inputs_arr = [param.strip() for param in inputs.split(',')]
        query_types_arr = [param.strip() for param in types.split(',')]
    else:
        # Free-form mode: the raw text is the query and no per-field arrays
        # are passed to the pipeline.
        query = inputs
        query_inputs_arr = []
        query_types_arr = []

    # `None` (not provided) is passed through as `None`, not as an empty list.
    keywords_arr = [param.strip() for param in keywords.split(',')] if keywords is not None else None

    try:
        rag = get_pipeline(agent)
        answer = rag.run_pipeline(agent, query_inputs_arr, query_types_arr, keywords_arr, query, file_path,
                                  index_name, options, group_by_rows, update_targets, debug)
    except ValueError as e:
        # Reported to the user but not re-raised, so the command still exits
        # with status 0. Other exception types propagate.
        print(f"Caught an exception: {e}")
        return

    print("\nJSON response:\n")
    print(answer)


def _safe_upload_filename(filename, default="upload"):
    """Reduce a client-supplied upload name to a bare file name.

    The name presumably comes from the API client, so it may carry directory
    components (``../../x``) or be absolute. ``os.path.join`` discards the
    preceding directory when its second argument is absolute, so joining the
    raw name onto a temporary directory could write outside it. Only the final
    path component is kept (which preserves the extension). ``default`` is
    returned when nothing usable remains (``None``, ``""``, ``"."``, ``".."``).
    """
    name = os.path.basename(filename or "")
    if name in ("", ".", ".."):
        return default
    return name


async def run_from_api_engine(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query, index_name,
                              options_arr, file, group_by_rows, update_targets, debug):
    """Run the selected RAG pipeline for an API request and return its answer.

    Args:
        user_selected_agent: Agent name, used both to look up the pipeline and
            as the first argument to ``run_pipeline``.
        query_inputs_arr: Field names to fetch (may be empty).
        query_types_arr: Field types (may be empty).
        keywords_arr: Table column keywords, or ``None``.
        query: Query text sent to the pipeline.
        index_name: Index identifier forwarded to the pipeline.
        options_arr: Agent options forwarded to the pipeline.
        file: Optional uploaded file exposing ``filename`` and an awaitable
            ``read()``, or ``None`` to run without a file.
        group_by_rows, update_targets, debug: Forwarded unchanged.

    Returns:
        Whatever ``run_pipeline`` returns.

    Raises:
        Any exception from ``get_pipeline``/``run_pipeline`` (including
        ``ValueError``) propagates unchanged.
    """
    rag = get_pipeline(user_selected_agent)

    def _run(file_path):
        # The trailing positional ``False`` is only passed on this API path;
        # the CLI ``run`` in this module omits it and relies on run_pipeline's
        # default. NOTE(review): its meaning is defined by run_pipeline — confirm.
        # NOTE(review): run_pipeline is called synchronously, so a long run
        # blocks the event loop; consider asyncio.to_thread if it is thread-safe.
        return rag.run_pipeline(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query,
                                file_path, index_name, options_arr, group_by_rows, update_targets, debug, False)

    if file is None:
        return _run(None)

    # run_pipeline is given a filesystem path, so the upload is written into a
    # temporary directory that is removed once the pipeline call returns.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_file_path = os.path.join(temp_dir, _safe_upload_filename(file.filename))
        content = await file.read()
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(content)
        return _run(temp_file_path)


if __name__ == "__main__":
    # Script entry point: Typer builds the command-line interface from run()'s
    # typer.Argument/typer.Option-annotated parameters and invokes it.
    typer.run(run)