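"""Command-line and API entry points for running a RAG extraction pipeline.

The `run` command (exposed via typer) and the `run_from_api_engine` coroutine
both delegate to a pipeline selected through `rag.agents.interface.get_pipeline`.
"""
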
import warnings
import typer
from typing import List, Optional
from typing_extensions import Annotated
from rag.agents.interface import get_pipeline
import tempfile
import os
from rich import print
# Disable parallelism in the Huggingface tokenizers library to prevent potential deadlocks and ensure consistent behavior.
# This is especially important in environments where multiprocessing is used, as forking after parallelism can lead to issues.
# Note: Disabling parallelism may impact performance, but it ensures safer and more predictable execution.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


def run(inputs: Annotated[str, typer.Argument(help="The list of fields to fetch")],
        types: Annotated[Optional[str], typer.Argument(help="The list of types of the fields")] = None,
        keywords: Annotated[Optional[str], typer.Argument(help="The list of table column keywords")] = None,
        file_path: Annotated[Optional[str], typer.Option(help="The file to process")] = None,
        agent: Annotated[str, typer.Option(help="Selected agent")] = "llamaindex",
        index_name: Annotated[Optional[str], typer.Option(help="Index to identify embeddings")] = None,
        options: Annotated[Optional[List[str]], typer.Option(help="Options to pass to the agent")] = None,
        group_by_rows: Annotated[bool, typer.Option(help="Group JSON collection by rows")] = True,
        update_targets: Annotated[bool, typer.Option(help="Update targets")] = True,
        debug: Annotated[bool, typer.Option(help="Enable debug mode")] = False):
    # Build the retrieval query; when no field types are given, the raw input
    # string is passed through as the query unchanged.
    query = 'retrieve ' + inputs
    query_types = types

    query_inputs_arr = [param.strip() for param in inputs.split(',')] if query_types else []
    query_types_arr = [param.strip() for param in query_types.split(',')] if query_types else []
    keywords_arr = [param.strip() for param in keywords.split(',')] if keywords is not None else None

    if not query_types:
        query = inputs
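    # For illustration (hypothetical values): inputs="invoice_number, total"
    # with types="str, float" yields query_inputs_arr=['invoice_number', 'total']
    # and query_types_arr=['str', 'float'], while keywords=None leaves
    # keywords_arr as None.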
    user_selected_agent = agent  # Modify this as needed

    try:
        rag = get_pipeline(user_selected_agent)
        answer = rag.run_pipeline(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query,
                                  file_path, index_name, options, group_by_rows, update_targets, debug)

        print("\nJSON response:\n")
        print(answer)
    except ValueError as e:
        print(f"Caught an exception: {e}")


async def run_from_api_engine(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query, index_name,
                              options_arr, file, group_by_rows, update_targets, debug):
    try:
        rag = get_pipeline(user_selected_agent)

        if file is not None:
            with tempfile.TemporaryDirectory() as temp_dir:
                temp_file_path = os.path.join(temp_dir, file.filename)

                # Save the uploaded file to the temporary directory
                with open(temp_file_path, 'wb') as temp_file:
                    content = await file.read()
                    temp_file.write(content)

                # Run the pipeline while the temporary directory still exists
                answer = rag.run_pipeline(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query,
                                          temp_file_path, index_name, options_arr, group_by_rows, update_targets,
                                          debug, False)
        else:
            answer = rag.run_pipeline(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query,
                                      None, index_name, options_arr, group_by_rows, update_targets,
                                      debug, False)
    except ValueError as e:
        raise e

    return answer
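
# A minimal sketch of how run_from_api_engine might be wired into a FastAPI
# endpoint. The app, route path, and form field names below are assumptions
# for illustration, not part of this module; only UploadFile's filename and
# read() interface is relied on above.
#
#   from fastapi import FastAPI, Form, UploadFile, File
#
#   app = FastAPI()
#
#   @app.post("/api/v1/inference")
#   async def inference(fields: str = Form(...), types: str = Form(...),
#                       file: UploadFile = File(...)):
#       query_inputs_arr = [f.strip() for f in fields.split(',')]
#       query_types_arr = [t.strip() for t in types.split(',')]
#       return await run_from_api_engine("llamaindex", query_inputs_arr, query_types_arr,
#                                        None, 'retrieve ' + fields, None, None, file,
#                                        True, True, False)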


if __name__ == "__main__":
    typer.run(run)
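
# Example CLI invocation (the script name, field names, and file path below
# are illustrative assumptions):
#   python cli.py "invoice_number, total_amount" "str, float" \
#       --file-path data/invoice.pdf --agent llamaindex --index-name invoices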