Spaces:
Running
Running
import warnings | |
import typer | |
from typing_extensions import Annotated, List | |
from rag.agents.interface import get_pipeline | |
import tempfile | |
import os | |
from rich import print | |
# Disable parallelism in the Huggingface tokenizers library to prevent potential deadlocks and ensure consistent behavior. | |
# This is especially important in environments where multiprocessing is used, as forking after parallelism can lead to issues. | |
# Note: Disabling parallelism may impact performance, but it ensures safer and more predictable execution. | |
os.environ['TOKENIZERS_PARALLELISM'] = 'false' | |
warnings.filterwarnings("ignore", category=DeprecationWarning) | |
warnings.filterwarnings("ignore", category=UserWarning) | |
def run(inputs: Annotated[str, typer.Argument(help="The list of fields to fetch")], | |
types: Annotated[str, typer.Argument(help="The list of types of the fields")] = None, | |
keywords: Annotated[str, typer.Argument(help="The list of table column keywords")] = None, | |
file_path: Annotated[str, typer.Option(help="The file to process")] = None, | |
agent: Annotated[str, typer.Option(help="Selected agent")] = "llamaindex", | |
index_name: Annotated[str, typer.Option(help="Index to identify embeddings")] = None, | |
options: Annotated[List[str], typer.Option(help="Options to pass to the agent")] = None, | |
group_by_rows: Annotated[bool, typer.Option(help="Group JSON collection by rows")] = True, | |
update_targets: Annotated[bool, typer.Option(help="Update targets")] = True, | |
debug: Annotated[bool, typer.Option(help="Enable debug mode")] = False): | |
query = 'retrieve ' + inputs | |
query_types = types | |
query_inputs_arr = [param.strip() for param in inputs.split(',')] if query_types else [] | |
query_types_arr = [param.strip() for param in query_types.split(',')] if query_types else [] | |
keywords_arr = [param.strip() for param in keywords.split(',')] if keywords is not None else None | |
if not query_types: | |
query = inputs | |
user_selected_agent = agent # Modify this as needed | |
try: | |
rag = get_pipeline(user_selected_agent) | |
answer = rag.run_pipeline(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query, file_path, | |
index_name, options, group_by_rows, update_targets, debug) | |
print(f"\nJSON response:\n") | |
print(answer) | |
except ValueError as e: | |
print(f"Caught an exception: {e}") | |
async def run_from_api_engine(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query, index_name, | |
options_arr, file, group_by_rows, update_targets, debug): | |
try: | |
rag = get_pipeline(user_selected_agent) | |
if file is not None: | |
with tempfile.TemporaryDirectory() as temp_dir: | |
temp_file_path = os.path.join(temp_dir, file.filename) | |
# Save the uploaded file to the temporary directory | |
with open(temp_file_path, 'wb') as temp_file: | |
content = await file.read() | |
temp_file.write(content) | |
answer = rag.run_pipeline(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query, | |
temp_file_path, index_name, options_arr, group_by_rows, update_targets, | |
debug, False) | |
else: | |
answer = rag.run_pipeline(user_selected_agent, query_inputs_arr, query_types_arr, keywords_arr, query, | |
None, index_name, options_arr, group_by_rows, update_targets, | |
debug, False) | |
except ValueError as e: | |
raise e | |
return answer | |
if __name__ == "__main__": | |
typer.run(run) | |