Spaces:
Sleeping
Sleeping
import gc | |
import logging | |
import os | |
import tempfile | |
from typing import Optional | |
import torch | |
from dotenv import load_dotenv | |
from langchain.agents import AgentExecutor, create_tool_calling_agent | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.rate_limiters import InMemoryRateLimiter | |
from langchain_core.tools import Tool | |
from langchain_experimental.utilities import PythonREPL | |
# from langchain_community.tools import DuckDuckGoSearchResults | |
# from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper | |
# from langchain_google_community import GoogleSearchAPIWrapper, GoogleSearchResults | |
from langchain_ollama import ChatOllama | |
from src.final_answer import create_final_answer_graph, validate_answer | |
from src.tools import ( | |
analyze_csv_file, | |
analyze_excel_file, | |
download_file_from_url, | |
duckduckgo_search, | |
extract_text_from_image, | |
read_file, | |
reverse_decoder, | |
review_youtube_video, | |
transcribe_audio, | |
transcribe_youtube, | |
use_vision_model, | |
video_frames_to_images, | |
website_scrape, | |
) | |
# Load variables from a local .env file (if present) before any setting is read.
load_dotenv()

# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL of the Ollama server; None when OLLAMA_BASE_URL is not set.
base_url = os.environ.get("OLLAMA_BASE_URL")

# In-memory limiter configured for 0.1 requests per second.
# NOTE(review): nothing in this module's visible code passes it to the model —
# confirm whether it is meant to be wired into ChatOllama.
rate_limiter = InMemoryRateLimiter(requests_per_second=0.1)
class BasicAgent:
    """Question-answering agent built on a LangChain tool-calling executor.

    Construction wires an Ollama chat model, a fixed set of tools and a
    final-answer validation graph together.  Calling the instance runs the
    agent on one question (with the task's file attached when one can be
    downloaded) and returns the first answer the validation graph accepts.
    """

    # Number of agent runs attempted per question before giving up.
    _MAX_RETRIES = 3

    def __init__(self):
        """Build the prompt, model, tools, executor and validation graph.

        Raises:
            Exception: Any error raised while building a component is logged
                with its traceback and re-raised unchanged.
        """
        try:
            logger.info("Initializing BasicAgent")

            prompt = self._build_prompt()
            logger.info("Created prompt template")

            llm = self._build_llm()
            logger.info("Created model successfully")

            tools = self._build_tools()
            # tools = [wrap_tool_with_limit(tool, max_calls=3) for tool in raw_tools]
            logger.info("Tools: %s", tools)

            agent = create_tool_calling_agent(llm, tools, prompt)
            logger.info("Created tool calling agent")

            self.agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                # __call__ reads the intermediate steps to detect repeated
                # runs and to feed the validation graph.
                return_intermediate_steps=True,
                verbose=True,
                max_iterations=5,
            )
            logger.info("Created agent executor")

            self.validation_graph = create_final_answer_graph()
        except Exception as e:
            logger.error("Error initializing agent: %s", e, exc_info=True)
            raise

    @staticmethod
    def _build_prompt():
        """Return the chat prompt: system rules, history, question, scratchpad."""
        return ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a general AI assistant. I will ask you a
                    question. Report your thoughts, and finish your answer
                    with the following template: FINAL ANSWER: [YOUR FINAL
                    ANSWER]. YOUR FINAL ANSWER should be a number OR as few
                    words as possible OR a comma separated list of numbers
                    and/or strings. If you are asked for a number, don't
                    use comma to write your number neither use units such
                    as $ or percent sign unless specified otherwise. If you
                    are asked for a string, don't use articles, neither
                    abbreviations (e.g. for cities), and write the digits
                    in plain text unless specified otherwise. If you are
                    asked for a comma separated list, apply the above rules
                    depending of whether the element to be put in the list
                    is a number or a string.
                    """,
                ),
                ("placeholder", "{chat_history}"),
                ("human", "{input}"),
                ("placeholder", "{agent_scratchpad}"),
            ]
        )

    @staticmethod
    def _build_llm():
        """Return the Ollama chat model used by the agent."""
        # NOTE(review): the module-level `rate_limiter` is not passed here;
        # confirm whether that is intentional.
        return ChatOllama(
            model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
            base_url=base_url,
            temperature=0.2,
        )

    @staticmethod
    def _build_tools() -> list:
        """Return the list of tools exposed to the agent."""
        # NOTE(review): several tool names contain spaces ("analyze csv file",
        # "reverse decoder", ...) — confirm the model backend accepts them.
        return [
            Tool(
                name="DuckDuckGoSearchResults",
                description="""Performs a live search using DuckDuckGo
                and analyzes the top results. Returns a summary including
                result titles, URLs, brief snippets, and ranking
                positions. Use this to quickly assess the relevance,
                diversity, and quality of information retrieved from a
                privacy-focused search engine, without personalized or
                biased filtering.""",
                # func=DuckDuckGoSearchResults(
                #     api_wrapper=DuckDuckGoSearchAPIWrapper()
                # ).run,
                func=duckduckgo_search,
            ),
            # Tool(
            #     name="GoogleSearchResults",
            #     description="""Performs a live Google search and analyzes
            #     the top results. Returns a summary including result titles,
            #     URLs, brief snippets, and ranking positions. Use this to
            #     quickly understand the relevance, variety, and quality of
            #     search results for a given query before deeper research or
            #     content planning.""",
            #     func=GoogleSearchResults(
            #         api_wrapper=GoogleSearchAPIWrapper(
            #             google_api_key=os.getenv("GOOGLE_SEARCH_API_KEY"),
            #             google_cse_id=os.getenv("GOOGLE_CSE_ID"),
            #             k=5,  # Number of results to return
            #         )
            #     ).run,
            # ),
            Tool(
                name="analyze csv file",
                description="""Only read and analyze the contents of a CSV
                file if one is explicitly referenced or uploaded in the
                question. When a CSV file is provided, return a summary of
                the dataset, including column names, data types, missing
                value counts, basic statistics for numeric fields, and a
                preview of the data. Use this only to quickly understand
                the structure and quality of the dataset before performing
                any further analysis.""",
                func=analyze_csv_file,
            ),
            Tool(
                name="analyze excel file",
                description="""Reads and analyzes the contents of an Excel
                file (.xlsx or .xls). Returns structured summaries
                for each sheet, including column names, data types, missing
                value counts, basic statistics for numeric columns, and
                sample rows. Use this to quickly explore the structure and
                quality of Excel datasets.""",
                func=analyze_excel_file,
            ),
            Tool(
                name="download file from url",
                description="""Downloads a file from a given URL and saves
                it locally. Supports various file types such as CSV, Excel,
                images, and PDFs. Use this to retrieve external resources
                for processing or analysis.""",
                func=download_file_from_url,
            ),
            Tool(
                name="extract_text_from_image",
                description="""Performs Optical Character Recognition (OCR)
                on an image to extract readable text after downloading it.
                Supports common image formats (e.g., PNG, JPG). Use this to
                digitize printed or handwritten content from images for
                search, analysis, or storage.""",
                func=extract_text_from_image,
            ),
            Tool(
                name="read_file",
                description="""Reads the raw content of a local text file.
                Supports formats such as .txt, .json, .xml, and markdown.
                Use this to load unstructured or semi-structured file
                content for display, parsing, or further
                processing—excluding CSV and Excel formats.""",
                func=read_file,
            ),
            Tool(
                name="review_youtube_video",
                description="""Analyzes a YouTube video by extracting key
                information such as title, description, view count, likes,
                comments, and transcript (if available). Use this to
                generate summaries, insights, or sentiment analysis based
                on video content and engagement.""",
                func=review_youtube_video,
            ),
            Tool(
                name="transcribe_audio",
                description="""Converts spoken words in an audio file into
                written text using speech-to-text technology. Supports
                common audio formats like MP3, WAV, and FLAC. Use this to
                create transcripts for meetings, interviews, podcasts, or
                any spoken content.""",
                func=transcribe_audio,
            ),
            Tool(
                name="transcribe_youtube",
                description="""Extracts and converts the audio from a
                YouTube video into text using speech-to-text technology.
                Supports generating transcripts for videos without captions
                or subtitles. Use this to obtain searchable, readable text
                from YouTube content.""",
                func=transcribe_youtube,
            ),
            Tool(
                name="use_vision_model",
                description="""Processes images using a computer vision
                model to perform tasks such as object detection, image
                classification, or segmentation. Use this to analyze visual
                content and extract meaningful information from images.""",
                func=use_vision_model,
            ),
            Tool(
                name="video_frames_to_images",
                description="""Extracts individual frames from a video file
                and saves them as separate image files. Use this to
                analyze, process, or visualize specific moments within
                video content. Use this to Youtube Videos""",
                func=video_frames_to_images,
            ),
            Tool(
                name="website_scrape",
                # NOTE(review): the text below refers to "duckduckgo_search",
                # but the search tool is registered above as
                # "DuckDuckGoSearchResults" — confirm which name is intended.
                description="""It is mandatory to use duckduckgo_search
                tool before invoking this tool .Fetches and extracts
                content from a specified website URL. Supports retrieving
                text, images, links, and other page elements.""",
                func=website_scrape,
            ),
            Tool(
                name="python_repl",
                # The backslash is doubled so the model sees the two
                # characters "\n" rather than an actual line break.
                description="""Write full, valid Python code using proper
                multi-line code blocks Do not escape newlines (\\n)
                instead, write each line of code on a separate line Always
                use proper indentation and syntax Return results using
                print() or return if using a function Avoid partial or
                inline code snippets — all code should be runnable in a
                Python REPL If the input is a function, include example
                usage at the end to ensure output is shown.""",
                func=PythonREPL().run,
                return_direct=True,
            ),
            # Tool(
            #     name="wiki",
            #     description="""Retrieves summarized information or
            #     detailed content from Wikipedia based on a user query.
            #     Use this to quickly access encyclopedic knowledge and
            #     relevant facts on a wide range of topics.""",
            #     func=wiki,
            # ),
            Tool(
                name="reverse decoder",
                description="""Decodes a reversed sentence if the input
                appears to be written backward.""",
                func=reverse_decoder,
            ),
        ]

    @staticmethod
    def _is_usable_file(file: Optional[dict]) -> bool:
        """Return True when *file* is a download result not marked as an error."""
        return bool(file) and file.get("type") != "error"

    @staticmethod
    def _fetch_task_file(task_id: str, directory: str) -> Optional[dict]:
        """Try to download the file for *task_id* into *directory*.

        Returns whatever the download tool returns, or None when the API URL
        is not configured or the download raises.
        """
        api_url = os.getenv("DEFAULT_API_URL")
        if not api_url:
            logger.warning(
                "DEFAULT_API_URL is not set; skipping file download for %s",
                task_id,
            )
            return None

        file_url = f"{api_url}/files/{task_id}"
        try:
            file = download_file_from_url.invoke(
                {
                    "url": file_url,
                    "directory": directory,
                }
            )
        except Exception as exc:
            # A failed download is not fatal: the question is then answered
            # without an attached file.
            logger.error("no download file available for %s: %s", task_id, exc)
            return None

        logger.info("Downloaded file: %s", file_url)
        return file

    @classmethod
    def _compose_input(cls, question: str, file: Optional[dict]) -> str:
        """Return *question*, suffixed with the file's type and path if usable."""
        if not cls._is_usable_file(file):
            return question
        return (
            question
            + f" [File: type={file.get('type', 'None')}, path={file.get('path', 'None')}]"
        )

    def __call__(self, question: str, task_id: str) -> str:
        """Execute the agent with the given question and optional file.

        The agent is run up to ``_MAX_RETRIES`` times.  Each run's answer is
        passed to the validation graph; the first answer it accepts is
        returned.  Retrying stops early when a run reproduces the exact
        intermediate steps of an earlier run.

        Args:
            question (str): The question to answer
            task_id (str): The task ID to fetch the file

        Returns:
            str: The final validated answer

        Raises:
            Exception: If no valid answer is found after max retries, if the
                last attempt itself fails, or if retrying was abandoned
                because the agent repeated itself.
        """
        max_retries = self._MAX_RETRIES
        seen_step_traces = set()
        file: Optional[dict] = None

        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                for attempt in range(1, max_retries + 1):
                    logger.info("Attempt %d of %d", attempt, max_retries)

                    # Reuse a successfully downloaded file across retries;
                    # only try again while no usable file is available.
                    if not self._is_usable_file(file):
                        file = self._fetch_task_file(task_id, temp_dir)

                    input_data = {"input": self._compose_input(question, file)}

                    try:
                        result = self.agent_executor.invoke(input_data)
                        answer = result.get("output", "")
                        intermediate_steps = result.get("intermediate_steps", [])

                        steps_str = str(intermediate_steps)
                        if steps_str in seen_step_traces:
                            logger.warning(
                                "Detected repeated reasoning steps on attempt %d. "
                                "Breaking loop to avoid infinite retry.",
                                attempt,
                            )
                            break
                        seen_step_traces.add(steps_str)
                        logger.info("Attempt %d result: %s", attempt, result)

                        validation_result = validate_answer(
                            self.validation_graph,  # type: ignore
                            answer,
                            [intermediate_steps],
                        )
                        valid_answer = validation_result.get("valid_answer", False)
                        final_answer = validation_result.get("final_answer", "")
                    except Exception as e:
                        logger.error(
                            "Error in attempt %d: %s", attempt, e, exc_info=True
                        )
                        if attempt >= max_retries:
                            raise Exception(
                                f"Failed after {max_retries} attempts. "
                                f"Last error: {e}"
                            ) from e
                        continue

                    if valid_answer:
                        logger.info("Valid answer found on attempt %d", attempt)
                        return final_answer

                    logger.warning(
                        "Validation failed on attempt %d: %s", attempt, final_answer
                    )
                    # Raised outside the try above so it is not caught and
                    # re-wrapped as an "Error in attempt" failure.
                    if attempt >= max_retries:
                        raise Exception(
                            f"Failed to get valid answer after {max_retries} "
                            f"attempts. Last error: {final_answer}"
                        )

            # Reached only when the loop was left via `break`.
            raise Exception("No valid answer found after processing")
        finally:
            # Free cached GPU memory and collect garbage on every exit path,
            # not only on success.
            torch.cuda.empty_cache()
            gc.collect()