Eduardo Guerra commited on
Commit
8d57271
·
1 Parent(s): eed0f02

feat: Final agent submission

Browse files
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  .env
2
  *__pycache__*
 
 
 
1
  .env
2
  *__pycache__*
3
+ *.DS_Store
4
+ *egg-info*
app.py CHANGED
@@ -4,6 +4,8 @@ import logging
4
  import os
5
  import sys
6
  import traceback
 
 
7
 
8
  import gradio as gr
9
  import pandas as pd
@@ -15,9 +17,6 @@ from src.agent import BasicAgent
15
  # Load environment variables from .env file
16
  load_dotenv()
17
 
18
- # Set OpenAI API key from environment variable
19
- os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
20
-
21
 
22
  # Configure logging
23
  logging.basicConfig(
@@ -29,7 +28,6 @@ logger = logging.getLogger(__name__)
29
 
30
  # (Keep Constants as is)
31
  # --- Constants ---
32
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
33
 
34
 
35
  def run_and_submit_all(profile: gr.OAuthProfile | None):
@@ -50,7 +48,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
50
  logger.warning("User not logged in.")
51
  return "Please Login to Hugging Face with the button.", None
52
 
53
- api_url = DEFAULT_API_URL
54
  questions_url = f"{api_url}/questions"
55
  submit_url = f"{api_url}/submit"
56
 
@@ -112,8 +110,26 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
112
  logger.info(f"Running agent on {len(questions_data)} questions...")
113
 
114
  # Limit the number of questions to process to avoid timeouts
115
- max_questions = 3 # Process only 3 questions at a time
116
- questions_to_process = questions_data[:max_questions]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  logger.info(
118
  f"Processing {len(questions_to_process)} out of {len(questions_data)} questions"
119
  )
@@ -131,37 +147,52 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
131
 
132
  # Use concurrent.futures for thread-safe timeout
133
  with concurrent.futures.ThreadPoolExecutor() as executor:
134
- future = executor.submit(agent, question_text)
135
  try:
136
- submitted_answer = future.result(
137
- timeout=60
138
- ) # 60 second timeout
139
- logger.info(
140
- f"Answer for task {task_id}: {submitted_answer}"
141
- )
142
-
143
- answers_payload.append(
144
- {
145
- "task_id": task_id,
146
- "submitted_answer": submitted_answer,
147
- }
148
- )
149
- results_log.append(
150
- {
151
- "Task ID": task_id,
152
- "Question": question_text,
153
- "Submitted Answer": submitted_answer,
154
- }
155
- )
156
- except concurrent.futures.TimeoutError:
157
- logger.error(f"Timeout processing task {task_id}")
158
- results_log.append(
159
- {
160
- "Task ID": task_id,
161
- "Question": question_text,
162
- "Submitted Answer": "TIMEOUT ERROR: Question processing timed out after 60 seconds",
163
- }
164
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  except Exception as e:
166
  logger.error(
167
  f"Error running agent on task {task_id}: {e}",
 
4
  import os
5
  import sys
6
  import traceback
7
+ import tempfile
8
+ import json
9
 
10
  import gradio as gr
11
  import pandas as pd
 
17
  # Load environment variables from .env file
18
  load_dotenv()
19
 
 
 
 
20
 
21
  # Configure logging
22
  logging.basicConfig(
 
28
 
29
  # (Keep Constants as is)
30
  # --- Constants ---
 
31
 
32
 
33
  def run_and_submit_all(profile: gr.OAuthProfile | None):
 
48
  logger.warning("User not logged in.")
49
  return "Please Login to Hugging Face with the button.", None
50
 
51
+ api_url = os.getenv("DEFAULT_API_URL")
52
  questions_url = f"{api_url}/questions"
53
  submit_url = f"{api_url}/submit"
54
 
 
110
  logger.info(f"Running agent on {len(questions_data)} questions...")
111
 
112
  # Limit the number of questions to process to avoid timeouts
113
+ max_questions = 20 # Process only 20 questions at a time
114
+
115
+ tasks_to_process = [
116
+ # "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
117
+ # "1f975693-876d-457b-a649-393859e79bf3",
118
+ # "840bfca7-4f7b-481a-8794-c560c340185d",
119
+ # "7bd855d8-463d-4ed5-93ca-5fe35145f733",
120
+ ]
121
+
122
+ # questions_to_process = questions_data[:max_questions]
123
+
124
+ if tasks_to_process:
125
+ questions_to_process = [
126
+ x
127
+ for x in questions_data
128
+ if x.get("task_id") in tasks_to_process
129
+ ]
130
+ else:
131
+ questions_to_process = questions_data[:max_questions]
132
+
133
  logger.info(
134
  f"Processing {len(questions_to_process)} out of {len(questions_data)} questions"
135
  )
 
147
 
148
  # Use concurrent.futures for thread-safe timeout
149
  with concurrent.futures.ThreadPoolExecutor() as executor:
150
+
151
  try:
152
+ future = executor.submit(agent, question_text, task_id)
153
+ try:
154
+ submitted_answer = future.result(
155
+ timeout=180
156
+ ) # 60 second timeout
157
+ logger.info(
158
+ f"Answer for task {task_id}: {submitted_answer}"
159
+ )
160
+
161
+ answers_payload.append(
162
+ {
163
+ "task_id": task_id,
164
+ "submitted_answer": submitted_answer,
165
+ }
166
+ )
167
+ results_log.append(
168
+ {
169
+ "Task ID": task_id,
170
+ "Question": question_text,
171
+ "Submitted Answer": submitted_answer,
172
+ }
173
+ )
174
+ except concurrent.futures.TimeoutError:
175
+ logger.error(f"Timeout processing task {task_id}")
176
+ results_log.append(
177
+ {
178
+ "Task ID": task_id,
179
+ "Question": question_text,
180
+ "Submitted Answer": "TIMEOUT ERROR: Question processing timed out after 60 seconds",
181
+ }
182
+ )
183
+ finally:
184
+ # Clean up temporary directory after processing
185
+ try:
186
+ import shutil
187
+
188
+ shutil.rmtree(temp_dir)
189
+ logger.info(
190
+ f"Cleaned up temporary directory for task {task_id}"
191
+ )
192
+ except Exception as e:
193
+ logger.error(
194
+ f"Error cleaning up temporary directory for task {task_id}: {e}"
195
+ )
196
  except Exception as e:
197
  logger.error(
198
  f"Error running agent on task {task_id}: {e}",
execute_script.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from google import genai
2
+ import os
3
+
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+ if __name__ == "__main__":
9
+ try:
10
+ client = genai.Client(api_key=os.getenv("GEMINI_KEY"))
11
+ models = client.models.list()
12
+
13
+ result = "Available Gemini Models:\n\n"
14
+ for model in models:
15
+ result += f"Model: {model.name}\n"
16
+ result += f"Description: {model.description}\n"
17
+ result += "-" * 50 + "\n"
18
+
19
+ print(result)
20
+ except Exception as e:
21
+ print(f"Error listing models: {str(e)}")
requirements.txt CHANGED
@@ -1,16 +1,30 @@
1
  beautifulsoup4==4.13.4
2
  duckduckgo-search==8.0.1
 
 
 
3
  gradio
 
 
 
4
  langchain-core==0.3.56
5
  langchain-community==0.3.23
6
- langchain-huggingface==0.1.2
7
- langchain-openai==0.3.14
 
8
  langgraph==0.3.34
9
  lxml==5.4.0
10
  nest-asyncio==1.6.0
 
11
  playwright==1.51.0
 
 
12
  python-dotenv==1.1.0
13
  requests
 
14
  sentencepiece==0.2.0
15
  torch==2.7.0
16
  transformers==4.51.3
 
 
 
 
1
  beautifulsoup4==4.13.4
2
  duckduckgo-search==8.0.1
3
+ google-ai-generativelanguage==0.6.15
4
+ google-genai==1.13.0
5
+ google-generativeai==0.8.5
6
  gradio
7
+ imageio
8
+ imageio[ffmpeg]
9
+ imageio[pyav]
10
  langchain-core==0.3.56
11
  langchain-community==0.3.23
12
+ langchain-experimental==0.3.4
13
+ langchain-google-genai==2.0.10
14
+ langchain-google-community==2.0.7
15
  langgraph==0.3.34
16
  lxml==5.4.0
17
  nest-asyncio==1.6.0
18
+ Pillow
19
  playwright==1.51.0
20
+ pytesseract
21
+ pytest==8.3.5
22
  python-dotenv==1.1.0
23
  requests
24
+ rizaio==0.11.0
25
  sentencepiece==0.2.0
26
  torch==2.7.0
27
  transformers==4.51.3
28
+ typing-extensions==4.13.2
29
+ youtube-transcript-api==1.0.3
30
+ yt-dlp==2025.4.30
setup.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+
4
+ def read_requirements():
5
+ with open("requirements.txt") as f:
6
+ return [
7
+ line.strip()
8
+ for line in f
9
+ if line.strip() and not line.startswith("#")
10
+ ]
11
+
12
+
13
+ setup(
14
+ name="src",
15
+ version="0.1",
16
+ packages=find_packages(),
17
+ install_requires=read_requirements(),
18
+ python_requires=">=3.8",
19
+ )
src/agent.py CHANGED
@@ -1,21 +1,31 @@
1
  import logging
2
  import os
 
 
3
 
4
- # This import is required only for jupyter notebooks, since they have their own eventloop
5
- import nest_asyncio
6
- from langchain.agents import AgentExecutor, create_tool_calling_agent, tool
7
- from langchain_community.agent_toolkits import PlayWrightBrowserToolkit
8
- from langchain_community.tools import DuckDuckGoSearchResults
9
- from langchain_community.tools.playwright.utils import (
10
- create_async_playwright_browser, # A synchronous browser is available, though it isn't compatible with jupyter.\n", },
11
- )
12
- from langchain_core.messages import AIMessage, HumanMessage
13
  from langchain_core.prompts import ChatPromptTemplate
14
- from langchain_openai import ChatOpenAI
 
 
15
 
16
- from src.tools.web_scrapper import web_scrapper_tool
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- nest_asyncio.apply()
19
  logger = logging.getLogger(__name__)
20
 
21
 
@@ -24,11 +34,13 @@ class BasicAgent:
24
  try:
25
  logger.info("Initializing BasicAgent")
26
 
 
27
  prompt = ChatPromptTemplate.from_messages(
28
  [
29
  (
30
  "system",
31
- "You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise, additionally, only use numbers, don't add any units and don't use any other characters. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.",
 
32
  ),
33
  ("placeholder", "{chat_history}"),
34
  ("human", "{input}"),
@@ -37,61 +49,140 @@ class BasicAgent:
37
  )
38
  logger.info("Created prompt template")
39
 
40
- # Log environment variables
41
- openai_api_key = os.getenv("OPENAI_API_KEY")
42
- logger.info(f"OPENAI_API_KEY exists: {openai_api_key is not None}")
43
-
44
- # Create OpenAI model
45
- logger.info("Creating OpenAI model...")
46
- llm = ChatOpenAI(
47
- model="gpt-3.5-turbo",
48
- openai_api_key=openai_api_key,
49
- temperature=0.7,
50
- max_tokens=1024,
51
  )
52
- logger.info("Created OpenAI model successfully")
53
-
54
- # async_browser = create_async_playwright_browser()
55
- # toolkit = PlayWrightBrowserToolkit.from_browser(
56
- # async_browser=async_browser
57
- # )
58
- # tools = toolkit.get_tools()
59
 
60
- tools = [DuckDuckGoSearchResults(), web_scrapper_tool()]
61
- logger.info(f"Tools: {tools}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
63
  agent = create_tool_calling_agent(llm, tools, prompt)
64
  logger.info("Created tool calling agent")
65
 
 
66
  self.agent_executor = AgentExecutor(
67
- agent=agent, tools=tools, verbose=True
 
 
 
68
  )
69
  logger.info("Created agent executor")
70
 
 
 
 
71
  except Exception as e:
72
- logger.error(f"Error initializing agent: {e}", exc_info=True)
73
  raise
74
 
75
- def __call__(self, question: str) -> str:
76
- try:
77
- logger.info(f"Processing question: {question}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- retries = 3
80
- while retries > 0:
81
  try:
82
- response = self.agent_executor.invoke({"input": question})[
83
- "output"
84
- ]
85
- response = response.split("FINAL ANSWER:")[1].strip()
86
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  except Exception as e:
88
  logger.error(
89
- f"Error processing question: {e}", exc_info=True
90
  )
91
- response = "Could not process question"
92
- retries -= 1
93
- logger.info(f"Response: {response}")
94
- return response
95
- except Exception as e:
96
- logger.error(f"Error processing question: {e}", exc_info=True)
97
- raise
 
1
  import logging
2
  import os
3
+ from typing import Optional, Dict
4
+ import tempfile
5
 
6
+ from langchain.agents import AgentExecutor, create_tool_calling_agent
7
+ from langchain_google_community import GoogleSearchResults
8
+ from langchain_google_community import GoogleSearchAPIWrapper
 
 
 
 
 
 
9
  from langchain_core.prompts import ChatPromptTemplate
10
+ from langchain_core.tools import Tool
11
+ from langchain_google_genai import ChatGoogleGenerativeAI
12
+ from langchain_experimental.utilities import PythonREPL
13
 
14
+ from src.final_answer import create_final_answer_graph, validate_answer
15
+ from src.tools import (
16
+ analyze_csv_file,
17
+ analyze_excel_file,
18
+ download_file_from_url,
19
+ extract_text_from_image,
20
+ read_file,
21
+ review_youtube_video,
22
+ transcribe_audio,
23
+ transcribe_youtube,
24
+ use_vision_model,
25
+ video_frames_to_images,
26
+ website_scrape,
27
+ )
28
 
 
29
  logger = logging.getLogger(__name__)
30
 
31
 
 
34
  try:
35
  logger.info("Initializing BasicAgent")
36
 
37
+ # Create the prompt template
38
  prompt = ChatPromptTemplate.from_messages(
39
  [
40
  (
41
  "system",
42
+ """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
43
+ """,
44
  ),
45
  ("placeholder", "{chat_history}"),
46
  ("human", "{input}"),
 
49
  )
50
  logger.info("Created prompt template")
51
 
52
+ # Initialize Gemini model
53
+ logger.info("Creating Gemini model...")
54
+ llm = ChatGoogleGenerativeAI(
55
+ model="models/gemini-2.5-pro-preview-03-25",
56
+ google_api_key=os.getenv("GEMINI_KEY"),
57
+ temperature=0.2,
 
 
 
 
 
58
  )
59
+ logger.info("Created Gemini model successfully")
 
 
 
 
 
 
60
 
61
+ # Define available tools
62
+ tools = [
63
+ GoogleSearchResults(
64
+ api_wrapper=GoogleSearchAPIWrapper(
65
+ google_api_key=os.getenv("GOOGLE_SEARCH_API_KEY"),
66
+ google_cse_id=os.getenv("GOOGLE_CSE_ID"),
67
+ k=5, # Number of results to return
68
+ )
69
+ ),
70
+ analyze_csv_file,
71
+ analyze_excel_file,
72
+ download_file_from_url,
73
+ extract_text_from_image,
74
+ read_file,
75
+ review_youtube_video,
76
+ transcribe_audio,
77
+ transcribe_youtube,
78
+ use_vision_model,
79
+ video_frames_to_images,
80
+ website_scrape,
81
+ Tool(
82
+ name="python_repl",
83
+ description="A Python shell. Use this to execute python commands. Input # should be a valid python command. If you want to see the output of a value, # you should print it out with `print(...)`.",
84
+ func=PythonREPL().run,
85
+ ),
86
+ ]
87
+ logger.info("Tools: %s", tools)
88
 
89
+ # Create the agent
90
  agent = create_tool_calling_agent(llm, tools, prompt)
91
  logger.info("Created tool calling agent")
92
 
93
+ # Create the agent executor
94
  self.agent_executor = AgentExecutor(
95
+ agent=agent,
96
+ tools=tools,
97
+ return_intermediate_steps=True,
98
+ verbose=True,
99
  )
100
  logger.info("Created agent executor")
101
 
102
+ # Create the graph
103
+ self.validation_graph = create_final_answer_graph()
104
+
105
  except Exception as e:
106
+ logger.error("Error initializing agent: %s", e, exc_info=True)
107
  raise
108
 
109
+ def __call__(self, question: str, task_id: str) -> str:
110
+ """Execute the agent with the given question and optional file.
111
+
112
+ Args:
113
+ question (str): The question to answer
114
+ task_id (str): The task ID to fetch the file
115
+ """
116
+ max_retries = 3
117
+ attempt = 0
118
+
119
+ # Create a temporary directory that will be automatically cleaned up
120
+ with tempfile.TemporaryDirectory() as temp_dir:
121
+ while attempt < max_retries:
122
+ default_api_url = os.getenv("DEFAULT_API_URL")
123
+ file_url = f"{default_api_url}/files/{task_id}"
124
+
125
+ try:
126
+ # Download file to temporary directory
127
+ file = download_file_from_url.invoke(
128
+ {
129
+ "url": file_url,
130
+ "directory": temp_dir,
131
+ }
132
+ )
133
+ except Exception as e:
134
+ logger.error(f"Error downloading file: {e}")
135
+ file = None
136
 
 
 
137
  try:
138
+ attempt += 1
139
+ logger.info(f"Attempt {attempt} of {max_retries}")
140
+
141
+ # Prepare input with file information
142
+ if file and file.get("type") != "error":
143
+ input_data = {
144
+ "input": question
145
+ + f" [File: type={file.get('type', 'None')}, path={file.get('path', 'None')}]",
146
+ }
147
+ else:
148
+ input_data = {
149
+ "input": question,
150
+ }
151
+
152
+ # Run the agent to get the answer
153
+ result = self.agent_executor.invoke(input_data)
154
+ answer = result.get("output", "")
155
+
156
+ logger.info(f"Attempt {attempt} result: {result}")
157
+
158
+ # Run validation
159
+ validation_result = validate_answer(
160
+ self.validation_graph,
161
+ answer,
162
+ [result.get("intermediate_steps", [])],
163
+ )
164
+
165
+ valid_answer = validation_result.get("valid_answer", False)
166
+ final_answer = validation_result.get("final_answer", "")
167
+
168
+ if valid_answer:
169
+ logger.info(f"Valid answer found on attempt {attempt}")
170
+ return final_answer
171
+
172
+ logger.warning(
173
+ f"Validation failed on attempt {attempt}: {final_answer}"
174
+ )
175
+ if attempt >= max_retries:
176
+ raise Exception(
177
+ f"Failed to get valid answer after {max_retries} attempts. Last error: {final_answer}"
178
+ )
179
+
180
  except Exception as e:
181
  logger.error(
182
+ f"Error in attempt {attempt}: {e}", exc_info=True
183
  )
184
+ if attempt >= max_retries:
185
+ raise Exception(
186
+ f"Failed after {max_retries} attempts. Last error: {str(e)}"
187
+ )
188
+ continue
 
 
src/final_answer.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from typing import Any, Dict, Optional
4
+ from typing_extensions import TypedDict
5
+
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.prompts import ChatPromptTemplate
8
+ from langgraph.graph import Graph, StateGraph, START, END
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+
11
+
12
+ class AgentState(TypedDict):
13
+ """State for the final answer validation graph."""
14
+
15
+ question: str
16
+ answer: str
17
+ final_answer: str | None
18
+ agent_memory: Any
19
+ valid_answer: bool
20
+
21
+
22
+ def extract_answer(state: AgentState) -> Dict:
23
+ """Extract and format the final answer from the state.
24
+
25
+ Args:
26
+ state: The state of the agent.
27
+
28
+ Returns:
29
+ A dictionary with the formatted final answer.
30
+ """
31
+ # Extract the final answer from the state
32
+ sep_token = "FINAL ANSWER:"
33
+ raw_answer = state["answer"]
34
+
35
+ # Extract the answer after the separator if it exists
36
+ if sep_token in raw_answer:
37
+ formatted_answer = raw_answer.split(sep_token)[1].strip()
38
+ else:
39
+ formatted_answer = raw_answer.strip()
40
+
41
+ # Remove any brackets from lists
42
+ formatted_answer = formatted_answer.replace("[", "").replace("]", "")
43
+
44
+ # Remove units unless specified
45
+ if not any(
46
+ unit in formatted_answer.lower()
47
+ for unit in ["$", "%", "dollars", "percent"]
48
+ ):
49
+ formatted_answer = formatted_answer.replace("$", "").replace("%", "")
50
+
51
+ # Remove commas from numbers
52
+ parts = formatted_answer.split(",")
53
+ formatted_parts = []
54
+ for part in parts:
55
+ part = part.strip()
56
+ if part.replace(".", "").isdigit(): # Check if it's a number
57
+ part = part.replace(",", "")
58
+ formatted_parts.append(part)
59
+ formatted_answer = ", ".join(formatted_parts)
60
+
61
+ return {"final_answer": formatted_answer}
62
+
63
+
64
+ def reasoning_check(state: AgentState) -> Dict:
65
+ """
66
+ Node that checks the reasoning of the final answer.
67
+
68
+ Args:
69
+ state: The state of the agent.
70
+
71
+ Returns:
72
+ A dictionary with the reasoning check result.
73
+ """
74
+ model = ChatGoogleGenerativeAI(
75
+ model="models/gemini-2.0-flash-lite",
76
+ google_api_key=os.getenv("GEMINI_KEY"),
77
+ temperature=0.2,
78
+ )
79
+ prompt = ChatPromptTemplate.from_messages(
80
+ [
81
+ (
82
+ "system",
83
+ """You are a strict validator of answers. Your job is to check if the reasoning and results are correct.
84
+ You should have >90% confidence that the answer is correct to pass it.
85
+ First list reasons why yes/no, then write your final decision: PASS in caps lock if it is satisfactory, FAIL if it is not.""",
86
+ ),
87
+ (
88
+ "human",
89
+ """
90
+ Here is a user-given task and the agent steps: {agent_memory}
91
+ Now here is the answer that was given: {final_answer}
92
+ Please check that the reasoning process and results are correct: do they correctly answer the given task?
93
+ """,
94
+ ),
95
+ ]
96
+ )
97
+
98
+ chain = prompt | model | StrOutputParser()
99
+ output = chain.invoke(
100
+ {
101
+ "agent_memory": state["agent_memory"],
102
+ "final_answer": state["final_answer"],
103
+ }
104
+ )
105
+
106
+ print("Reasoning Feedback: ", output)
107
+ if "FAIL" in output:
108
+ return {"valid_answer": False}
109
+ return {"valid_answer": True}
110
+
111
+
112
+ def formatting_check(state: AgentState) -> Dict:
113
+ """
114
+ Node that checks the formatting of the final answer.
115
+
116
+ Args:
117
+ state: The state of the agent.
118
+
119
+ Returns:
120
+ A dictionary with the formatting check result.
121
+ """
122
+ model = ChatGoogleGenerativeAI(
123
+ model="models/gemini-2.0-flash-lite",
124
+ google_api_key=os.getenv("GEMINI_KEY"),
125
+ temperature=0.2,
126
+ )
127
+ prompt = ChatPromptTemplate.from_messages(
128
+ [
129
+ (
130
+ "system",
131
+ """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
132
+ """,
133
+ ),
134
+ (
135
+ "human",
136
+ """
137
+ Here is a user-given task and the agent steps: {agent_memory}
138
+ Now here is the FINAL ANSWER that was given: {final_answer}
139
+ Ensure the FINAL ANSWER is in the right format as asked for by the task.
140
+ """,
141
+ ),
142
+ ]
143
+ )
144
+
145
+ chain = prompt | model | StrOutputParser()
146
+ output = chain.invoke(
147
+ {
148
+ "agent_memory": state["agent_memory"],
149
+ "final_answer": state["final_answer"],
150
+ }
151
+ )
152
+
153
+ print("Formatting Feedback: ", output)
154
+ if "FAIL" in output:
155
+ return {"valid_answer": False}
156
+ return {"valid_answer": True}
157
+
158
+
159
+ def create_final_answer_graph() -> Graph:
160
+ """Create a graph that validates the final answer.
161
+
162
+ Returns:
163
+ A graph that validates the final answer.
164
+ """
165
+ # Create the graph
166
+ workflow = StateGraph(AgentState)
167
+
168
+ # Add nodes
169
+ workflow.add_node("extract_answer", extract_answer)
170
+ workflow.add_node("reasoning_check", reasoning_check)
171
+ workflow.add_node("formatting_check", formatting_check)
172
+
173
+ # Add edges
174
+ workflow.add_edge(START, "extract_answer")
175
+ workflow.add_edge("extract_answer", "reasoning_check")
176
+ workflow.add_edge("reasoning_check", "formatting_check")
177
+ workflow.add_edge("formatting_check", END)
178
+
179
+ # Compile the graph
180
+ return workflow.compile()
181
+
182
+
183
+ def validate_answer(graph: Graph, answer: str, agent_memory: Any) -> Dict:
184
+ """Validate the answer using the LangGraph workflow.
185
+
186
+ Args:
187
+ graph: The validation graph.
188
+ answer: The answer to validate.
189
+ agent_memory: The agent's memory.
190
+
191
+ Returns:
192
+ A dictionary with validation results.
193
+ """
194
+ try:
195
+ # Initialize state
196
+ initial_state = {
197
+ "answer": answer,
198
+ "final_answer": None,
199
+ "agent_memory": agent_memory,
200
+ "valid_answer": False,
201
+ }
202
+
203
+ # Run the graph
204
+ result = graph.invoke(initial_state)
205
+
206
+ return {
207
+ "valid_answer": result.get("valid_answer", False),
208
+ "final_answer": result.get("final_answer", None),
209
+ }
210
+ except Exception as e:
211
+ print(f"Validation failed: {e}")
212
+ return {"valid_answer": False, "final_answer": None}
src/tools.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import shutil
3
+ import os
4
+ import tempfile
5
+ import uuid
6
+ from typing import List, Optional, Dict, Union
7
+ import re
8
+ import time
9
+ from datetime import datetime, timedelta
10
+ from bs4 import BeautifulSoup
11
+ from playwright.sync_api import sync_playwright
12
+
13
+ import imageio
14
+ import pandas as pd
15
+ import pytesseract
16
+ import requests
17
+ import yt_dlp
18
+ from dotenv import load_dotenv
19
+ from google import genai
20
+ from google.genai import types
21
+ from langchain_core.tools import tool
22
+ from PIL import Image
23
+ from youtube_transcript_api import YouTubeTranscriptApi
24
+
25
+ load_dotenv()
26
+
27
+
28
+ # Vision Model Tool
29
+ @tool
30
+ def use_vision_model(
31
+ question: str, image_paths: List[str], mime_type: str
32
+ ) -> str:
33
+ """Use a Vision Model to answer a question about a set of images.
34
+
35
+ Args:
36
+ question (str): The question you are asking about the images.
37
+ image_paths (List[str]): The paths to the images to use for the question.
38
+ mime_type (str): The mime type of the image.
39
+
40
+ Returns:
41
+ str: The answer to the question
42
+ """
43
+ try:
44
+ client = genai.Client(api_key=os.getenv("GEMINI_KEY"))
45
+ model = "models/gemini-2.0-flash-001"
46
+
47
+ # Prepare the content parts
48
+ parts = []
49
+ for image_path in image_paths:
50
+ with open(image_path, "rb") as f:
51
+ image_bytes = f.read()
52
+
53
+ response = []
54
+
55
+ for chunk in client.models.generate_content_stream(
56
+ model=model,
57
+ contents=[
58
+ question,
59
+ types.Part.from_bytes(data=image_bytes, mime_type=mime_type),
60
+ ],
61
+ ):
62
+ response.append(chunk.text)
63
+
64
+ return " ".join(response)
65
+
66
+ except Exception as e:
67
+ return f"Error using vision model: {str(e)}"
68
+
69
+
70
+ # YouTube Video Review Tool
71
+ @tool
72
+ def review_youtube_video(url: str, question: str) -> str:
73
+ """Reviews a YouTube video and answers a specific question about that video.
74
+
75
+ Args:
76
+ url (str): the URL to the YouTube video.
77
+ question (str): The question you are asking about the video
78
+
79
+ Returns:
80
+ str: The answer to the question
81
+ """
82
+ try:
83
+ client = genai.Client(api_key=os.getenv("GEMINI_KEY"))
84
+ model = "models/gemini-1.5-flash-8b"
85
+
86
+ response = client.models.generate_content(
87
+ model=model,
88
+ contents=types.Content(
89
+ parts=[
90
+ types.Part(file_data=types.FileData(file_uri=url)),
91
+ types.Part(text=question),
92
+ ]
93
+ ),
94
+ )
95
+ return response.text
96
+ except Exception as e:
97
+ return f"Error asking {model} about video: {str(e)}"
98
+
99
+
100
+ # YouTube Frames to Images Tool
101
+ @tool
102
+ def video_frames_to_images(
103
+ url: str,
104
+ folder_name: str,
105
+ sample_interval_seconds: int = 5,
106
+ ) -> List[str]:
107
+ """Extracts frames from a video at specified intervals and saves them as images.
108
+
109
+ Args:
110
+ url (str): the URL to the video.
111
+ folder_name (str): the name of the folder to save the images to.
112
+ sample_interval_seconds (int): the interval between frames to sample.
113
+
114
+ Returns:
115
+ List[str]: A list of paths to the saved image files.
116
+ """
117
+ # Create a subdirectory for the frames
118
+ frames_dir = os.path.join(folder_name, "frames")
119
+ os.makedirs(frames_dir, exist_ok=True)
120
+
121
+ ydl_opts = {
122
+ "format": "bestvideo[height<=1080]+bestaudio/best[height<=1080]/best",
123
+ "outtmpl": os.path.join(folder_name, "video.%(ext)s"),
124
+ "quiet": True,
125
+ "noplaylist": True,
126
+ "merge_output_format": "mp4",
127
+ "force_ipv4": True,
128
+ }
129
+
130
+ try:
131
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
132
+ info = ydl.extract_info(url, download=True)
133
+ video_path = next(
134
+ (
135
+ os.path.join(folder_name, f)
136
+ for f in os.listdir(folder_name)
137
+ if f.endswith(".mp4")
138
+ ),
139
+ None,
140
+ )
141
+
142
+ if not video_path:
143
+ raise RuntimeError("Failed to download video as mp4")
144
+
145
+ reader = imageio.get_reader(video_path)
146
+ metadata = reader.get_meta_data()
147
+ fps = metadata.get("fps")
148
+
149
+ if fps is None:
150
+ reader.close()
151
+ raise RuntimeError(
152
+ "Unable to determine FPS from video metadata"
153
+ )
154
+
155
+ frame_interval = int(fps * sample_interval_seconds)
156
+ image_paths: List[str] = []
157
+
158
+ for idx, frame in enumerate(reader):
159
+ if idx % frame_interval == 0:
160
+ # Save frame as image
161
+ image_path = os.path.join(
162
+ frames_dir, f"frame_{idx:06d}.jpg"
163
+ )
164
+ imageio.imwrite(image_path, frame)
165
+ image_paths.append(image_path)
166
+
167
+ reader.close()
168
+ return image_paths
169
+
170
+ except Exception as e:
171
+ raise RuntimeError(f"Error processing video frames: {str(e)}") from e
172
+
173
+
174
+ # File Reading Tool
175
+ @tool
176
+ def read_file(filepath: str) -> str:
177
+ """Reads the content of a text file.
178
+
179
+ Args:
180
+ filepath (str): the path to the file to read.
181
+
182
+ Returns:
183
+ str: The content of the file.
184
+ """
185
+ try:
186
+ with open(filepath, "r", encoding="utf-8") as file:
187
+ content = file.read()
188
+ return content
189
+ except FileNotFoundError:
190
+ return f"File not found: {filepath}"
191
+ except IOError as e:
192
+ return f"Error reading file: {str(e)}"
193
+
194
+
195
+ # File Download Tool
196
+ @tool
197
+ def download_file_from_url(
198
+ url: str, directory: str
199
+ ) -> Dict[str, Union[str, None]]:
200
+ """Downloads a file from a URL and saves it to a directory.
201
+
202
+ Args:
203
+ url (str): the URL to download the file from.
204
+ directory (str): the directory to save the file to.
205
+
206
+ Returns:
207
+ Dict[str, Union[str, None]]: A dictionary containing the file type and path.
208
+ """
209
+
210
+ try:
211
+ response = requests.get(url, stream=True, timeout=10)
212
+ response.raise_for_status()
213
+
214
+ content_type = response.headers.get("content-type", "").lower()
215
+
216
+ # Try to get filename from headers
217
+ filename = None
218
+ cd = response.headers.get("content-disposition", "")
219
+ match = re.search(r"filename\*=UTF-8\'\'(.+)", cd) or re.search(
220
+ r'filename="?([^"]+)"?', cd
221
+ )
222
+ if match:
223
+ filename = match.group(1)
224
+
225
+ # If not in headers, try URL
226
+ if not filename:
227
+ filename = os.path.basename(url.split("?")[0])
228
+
229
+ # Fallback to generated filename
230
+ if not filename:
231
+ extension = {
232
+ "image/jpeg": ".jpg",
233
+ "image/png": ".png",
234
+ "image/gif": ".gif",
235
+ "audio/wav": ".wav",
236
+ "audio/mpeg": ".mp3",
237
+ "video/mp4": ".mp4",
238
+ "text/plain": ".txt",
239
+ "text/csv": ".csv",
240
+ "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
241
+ "application/vnd.ms-excel": ".xls",
242
+ "application/octet-stream": ".bin",
243
+ }.get(content_type, ".bin")
244
+ filename = f"downloaded_{uuid.uuid4().hex[:8]}{extension}"
245
+
246
+ os.makedirs(directory, exist_ok=True)
247
+ file_path = os.path.join(directory, filename)
248
+
249
+ with open(file_path, "wb") as f:
250
+ for chunk in response.iter_content(chunk_size=8192):
251
+ f.write(chunk)
252
+
253
+ # shutil.copy(file_path, os.getcwd())
254
+
255
+ if os.path.exists(file_path) and os.path.getsize(file_path) > 0:
256
+ return {"type": content_type, "path": file_path}
257
+ else:
258
+ return {
259
+ "type": "error",
260
+ "path": None,
261
+ "error": "Failed to save file",
262
+ }
263
+
264
+ except Exception as e:
265
+ return {
266
+ "type": "error",
267
+ "path": None,
268
+ "error": f"Error downloading file: {str(e)}",
269
+ }
270
+
271
+
272
+ # Text Extraction from Image Tool
273
+ @tool
274
+ def extract_text_from_image(image_path: str) -> str:
275
+ """Extracts text from an image using OCR.
276
+
277
+ Args:
278
+ image_path (str): the path to the image to extract text from.
279
+
280
+ Returns:
281
+ str: The text extracted from the image.
282
+ """
283
+ try:
284
+
285
+ image = Image.open(image_path)
286
+ text = pytesseract.image_to_string(image)
287
+ return f"Extracted text from image:\n\n{text}"
288
+ except Exception as e:
289
+ return f"Error extracting text from image: {str(e)}"
290
+
291
+
292
+ # CSV Analysis Tool
293
+ @tool
294
+ def analyze_csv_file(file_path: str, query: str) -> str:
295
+ """Analyzes a CSV file and answers questions about its contents using Gemini.
296
+
297
+ Args:
298
+ file_path (str): the path to the CSV file to analyze.
299
+ query (str): the question to answer about the CSV file.
300
+
301
+ Returns:
302
+ str: The result of the analysis.
303
+ """
304
+ try:
305
+ # Read the CSV file
306
+ df = pd.read_csv(file_path)
307
+
308
+ # Initialize Gemini
309
+ client = genai.Client(api_key=os.getenv("GEMINI_KEY"))
310
+ model = "models/gemini-1.5-flash-8b"
311
+
312
+ # Convert DataFrame to a string representation
313
+ df_str = df.to_string()
314
+
315
+ # Create a prompt for Gemini
316
+ prompt = f"""Analyze this CSV data and provide insights:
317
+
318
+ Dimensions: {len(df)} rows × {len(df.columns)} columns
319
+
320
+ Data:
321
+ {df_str}
322
+
323
+ Please provide:
324
+ 1. A summary of the data structure and content
325
+ 2. Key patterns and insights
326
+ 3. Potential data quality issues
327
+ 4. Suggestions for analysis
328
+
329
+ User Query: {query}
330
+
331
+ Please format your response in a clear, structured way with sections and bullet points."""
332
+
333
+ # Get analysis from Gemini
334
+ response = client.models.generate_content(
335
+ model=model,
336
+ contents=types.Content(
337
+ parts=[
338
+ types.Part(text=df_str),
339
+ types.Part(text=prompt),
340
+ ]
341
+ ),
342
+ )
343
+
344
+ result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n\n"
345
+ result += response.text
346
+
347
+ return result
348
+ except Exception as e:
349
+ return f"Error analyzing CSV file: {str(e)}"
350
+
351
+
352
+ # Excel Analysis Tool
353
+ @tool
354
+ def analyze_excel_file(file_path: str, query: str) -> str:
355
+ """Analyzes an Excel file and answers questions about its contents using Gemini.
356
+
357
+ Args:
358
+ file_path (str): the path to the Excel file to analyze.
359
+ query (str): the question to answer about the Excel file.
360
+
361
+ Returns:
362
+ str: The result of the analysis.
363
+ """
364
+ try:
365
+ # Read all sheets from the Excel file
366
+ excel_file = pd.ExcelFile(file_path)
367
+ sheet_names = excel_file.sheet_names
368
+
369
+ # Initialize Gemini
370
+ client = genai.Client(api_key=os.getenv("GEMINI_KEY"))
371
+ model = "models/gemini-1.5-flash-8b"
372
+
373
+ result = f"Excel file loaded with {len(sheet_names)} sheets: {', '.join(sheet_names)}\n\n"
374
+
375
+ # Analyze each sheet
376
+ for sheet_name in sheet_names:
377
+ df = pd.read_excel(file_path, sheet_name=sheet_name)
378
+
379
+ # Convert DataFrame to a string representation
380
+ df_str = df.to_string()
381
+
382
+ # Create a prompt for Gemini
383
+ prompt = f"""Analyze this Excel sheet data and provide insights:
384
+
385
+ Sheet Name: {sheet_name}
386
+ Dimensions: {len(df)} rows × {len(df.columns)} columns
387
+
388
+ Data:
389
+ {df_str}
390
+
391
+ Please provide:
392
+ 1. A summary of the data structure and content
393
+ 2. Key patterns and insights
394
+ 3. Potential data quality issues
395
+ 4. Suggestions for analysis
396
+
397
+ User Query: {query}
398
+
399
+ Please format your response in a clear, structured way with sections and bullet points."""
400
+
401
+ # Get analysis from Gemini
402
+ response = client.models.generate_content(
403
+ model=model,
404
+ contents=types.Content(
405
+ parts=[types.Part(text=df_str), types.Part(text=prompt)]
406
+ ),
407
+ )
408
+
409
+ result += f"=== Sheet: {sheet_name} ===\n"
410
+ result += response.text + "\n"
411
+ result += "=" * 50 + "\n\n"
412
+
413
+ return result
414
+ except Exception as e:
415
+ return f"Error analyzing Excel file: {str(e)}"
416
+
417
+
418
+ # Audio Transcription Tool
419
+ @tool
420
+ def transcribe_audio(audio_file_path: str, mime_type: str) -> str:
421
+ """Transcribes an audio file using Gemini's audio capabilities.
422
+
423
+ Args:
424
+ audio_file_path (str): the path to the audio file to transcribe.
425
+ mime_type (str): the mime type of the audio file.
426
+
427
+ Returns:
428
+ str: The transcript of the audio file.
429
+ """
430
+ try:
431
+ # Initialize the model
432
+ client = genai.Client(api_key=os.getenv("GEMINI_KEY"))
433
+ model = "models/gemini-1.5-flash-8b"
434
+
435
+ # Read and encode the audio file
436
+ with open(audio_file_path, "rb") as audio_file:
437
+ audio_data = audio_file.read()
438
+
439
+ # Create the content with audio data
440
+ contents = types.Content(
441
+ parts=[
442
+ types.Part.from_bytes(
443
+ data=audio_data,
444
+ mime_type=mime_type,
445
+ ),
446
+ types.Part(text="Please transcribe this audio file."),
447
+ ]
448
+ )
449
+
450
+ # Generate transcription
451
+ response = client.models.generate_content(
452
+ model=model, contents=contents
453
+ )
454
+ return response.text
455
+ except Exception as e:
456
+ return f"Error transcribing audio: {str(e)}"
457
+
458
+
459
+ def _extract_video_id(url: str) -> Optional[str]:
460
+ """Extract video ID from YouTube URL.
461
+
462
+ Args:
463
+ url (str): the URL to the YouTube video.
464
+
465
+ Returns:
466
+ str: The video ID of the YouTube video.
467
+ """
468
+ patterns = [
469
+ r"(?:youtube\.com\/watch\?v=|youtube\.com\/embed\/|youtu\.be\/)([^&\n?#]+)",
470
+ r"(?:youtube\.com\/v\/|youtube\.com\/e\/|youtube\.com\/user\/[^\/]+\/|youtube\.com\/[^\/]+\/|youtube\.com\/embed\/|youtu\.be\/)([^&\n?#]+)",
471
+ ]
472
+
473
+ for pattern in patterns:
474
+ match = re.search(pattern, url)
475
+ if match:
476
+ return match.group(1)
477
+ return None
478
+
479
+
480
+ @tool
481
+ def transcribe_youtube(url: str) -> str:
482
+ """Transcribes a YouTube video using YouTube Transcript API or Gemini as fallback.
483
+
484
+ Args:
485
+ url (str): the URL to the YouTube video.
486
+
487
+ Returns:
488
+ str: The transcript of the YouTube video.
489
+ """
490
+ try:
491
+ # First try using YouTube Transcript API
492
+ video_id = _extract_video_id(url)
493
+ if not video_id:
494
+ raise ValueError(f"Invalid YouTube URL: {url}")
495
+
496
+ try:
497
+ # Try to get transcript in English
498
+ transcript_chunks = YouTubeTranscriptApi.get_transcript(
499
+ video_id, languages=["en"]
500
+ )
501
+ # Combine all chunks into a single transcript with timestamps
502
+ transcript = ""
503
+ for chunk in transcript_chunks:
504
+ timestamp = str(timedelta(seconds=int(chunk["start"])))
505
+ transcript += f"[{timestamp}] {chunk['text']}\n"
506
+ return transcript
507
+
508
+ except Exception as transcript_error:
509
+ print(
510
+ f"Failed to get transcript using YouTube API: {str(transcript_error)}"
511
+ )
512
+ print("Falling back to Gemini-based transcription...")
513
+
514
+ # Fallback to Gemini-based transcription
515
+ with tempfile.TemporaryDirectory() as tmpdir:
516
+ # Download audio from YouTube
517
+ ydl_opts = {
518
+ "format": "bestaudio/best",
519
+ "outtmpl": os.path.join(tmpdir, "audio.%(ext)s"),
520
+ "quiet": True,
521
+ "noplaylist": True,
522
+ "postprocessors": [
523
+ {
524
+ "key": "FFmpegExtractAudio",
525
+ "preferredcodec": "wav",
526
+ "preferredquality": "192",
527
+ }
528
+ ],
529
+ }
530
+
531
+ try:
532
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
533
+ info = ydl.extract_info(url, download=True)
534
+ audio_path = next(
535
+ (
536
+ os.path.join(tmpdir, f)
537
+ for f in os.listdir(tmpdir)
538
+ if f.endswith(".wav")
539
+ ),
540
+ None,
541
+ )
542
+
543
+ if not audio_path:
544
+ raise RuntimeError(
545
+ "Failed to download audio"
546
+ ) from transcript_error
547
+
548
+ # Use Gemini to transcribe the audio
549
+ client = genai.Client(api_key=os.getenv("GEMINI_KEY"))
550
+ model = "models/gemini-1.5-flash-8b"
551
+
552
+ # Read the audio file
553
+ with open(audio_path, "rb") as audio_file:
554
+ audio_data = audio_file.read()
555
+
556
+ # Create the content with audio data
557
+ contents = types.Content(
558
+ parts=[
559
+ types.Part(
560
+ file_data=types.FileData(
561
+ mime_type="audio/wav",
562
+ data=audio_data,
563
+ )
564
+ ),
565
+ types.Part(
566
+ text="Please transcribe this audio file. Include timestamps if possible."
567
+ ),
568
+ ]
569
+ )
570
+
571
+ # Generate transcription
572
+ response = client.models.generate_content(
573
+ model=model, contents=contents
574
+ )
575
+ return response.text
576
+
577
+ except yt_dlp.utils.DownloadError as e:
578
+ raise RuntimeError(
579
+ f"Error downloading YouTube video: {str(e)}"
580
+ ) from transcript_error
581
+ except Exception as e:
582
+ raise RuntimeError(
583
+ f"Error processing YouTube video: {str(e)}"
584
+ ) from transcript_error
585
+
586
+ except Exception as e:
587
+ raise RuntimeError(f"Error in YouTube transcription: {str(e)}") from e
588
+
589
+
590
+ @tool
591
+ def website_scrape(url: str, question: str) -> str:
592
+ """Scrapes a website and returns the text.
593
+
594
+ Args:
595
+ url (str): the URL to the website to scrape.
596
+
597
+ Returns:
598
+ str: The text of the website.
599
+ """
600
+
601
+ with sync_playwright() as p:
602
+ browser = p.chromium.launch(headless=True)
603
+ page = browser.new_page()
604
+ page.goto(url)
605
+ html_content = page.content()
606
+ browser.close()
607
+
608
+ soup = BeautifulSoup(html_content, "html.parser")
609
+
610
+ # Extract text from the website
611
+ text = soup.get_text()
612
+
613
+ return text
src/tools/__init__.py DELETED
File without changes
src/tools/image_to_text.py DELETED
File without changes
src/tools/web_scrapper.py DELETED
@@ -1,23 +0,0 @@
1
- from bs4 import BeautifulSoup
2
- from langgraph import Tool
3
- from playwright.sync_api import sync_playwright
4
-
5
-
6
- def extract_website_content(url: str) -> str:
7
- with sync_playwright() as p:
8
- browser = p.chromium.launch(headless=True)
9
- page = browser.new_page()
10
- page.goto(url)
11
- html_content = page.content()
12
- browser.close()
13
-
14
- soup = BeautifulSoup(html_content, "html.parser")
15
- return soup.get_text()
16
-
17
-
18
- def web_scrapper_tool():
19
- return Tool.from_function(
20
- func=extract_website_content,
21
- name="scrape_website",
22
- description="Extracts the main content of a webpage given its URL.",
23
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/test_tools.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pytest
3
+ from dotenv import load_dotenv
4
+ import tempfile
5
+ from src.tools import (
6
+ use_vision_model,
7
+ review_youtube_video,
8
+ video_frames_to_images,
9
+ read_file,
10
+ download_file_from_url,
11
+ extract_text_from_image,
12
+ analyze_csv_file,
13
+ analyze_excel_file,
14
+ transcribe_audio,
15
+ transcribe_youtube,
16
+ website_scrape,
17
+ )
18
+
19
+ # Load environment variables
20
+ load_dotenv()
21
+
22
+
23
+ @pytest.fixture
24
+ def test_dir(tmp_path):
25
+ """Create a temporary directory for test files."""
26
+ return tmp_path
27
+
28
+
29
+ def test_website_scrape():
30
+ """Test website scraping functionality."""
31
+ url = "https://en.wikipedia.org/wiki/2025_World_Snooker_Championship"
32
+ question = "What is the main heading?"
33
+ result = website_scrape.invoke({"url": url, "question": question})
34
+ assert isinstance(result, str)
35
+ assert len(result) > 0
36
+ print("\nWebsite Scrape Test Result:", result[:200])
37
+
38
+
39
+ def test_read_file(test_dir):
40
+ """Test file reading functionality."""
41
+ # Create a test file
42
+ test_file = test_dir / "test.txt"
43
+ test_file.write_text("Test content")
44
+
45
+ result = read_file.invoke({"filepath": str(test_file)})
46
+ assert isinstance(result, str)
47
+ assert result == "Test content"
48
+ print("\nRead File Test Result:", result)
49
+
50
+
51
+ def test_download_file_from_url():
52
+ """Test file downloading functionality."""
53
+ path = "https://fastly.picsum.photos/id/856/400/400.jpg?hmac=tb7tfZIDAlSxzTJ6V0l3sJH4CxcWXW1z4aiWrqbbQSs"
54
+
55
+ with tempfile.TemporaryDirectory() as temp_dir:
56
+ temp_file = os.path.join(temp_dir, "test.jpg")
57
+
58
+ print(f"Downloading file to: {temp_file}")
59
+
60
+ result = download_file_from_url.invoke(
61
+ {"url": path, "file_path": temp_file}
62
+ )
63
+ assert isinstance(result, str)
64
+ assert os.path.exists(temp_file)
65
+
66
+
67
+ def test_extract_text_from_image():
68
+ """Test OCR functionality."""
69
+ image_path = "test_files/text_in_image.jpg"
70
+ result = extract_text_from_image.invoke({"image_path": image_path})
71
+ assert isinstance(result, str)
72
+ print("\nExtract Text Test Result:", result)
73
+
74
+
75
+ def test_analyze_csv_file(test_dir):
76
+ """Test CSV analysis functionality."""
77
+ # Create a test CSV file
78
+ file_path = "test_files/customers-100.csv"
79
+
80
+ result = analyze_csv_file.invoke(
81
+ {
82
+ "file_path": file_path,
83
+ "query": "What is the first name of the first customer?",
84
+ }
85
+ )
86
+ assert isinstance(result, str)
87
+ assert "CSV file loaded" in result
88
+ print("\nAnalyze CSV Test Result:", result)
89
+
90
+
91
+ def test_analyze_excel_file():
92
+ """Test Excel analysis functionality."""
93
+ excel_path = "test_files/Project-Management-Sample-Data.xlsx"
94
+ result = analyze_excel_file.invoke(
95
+ {
96
+ "file_path": excel_path,
97
+ "query": "What is the name of the first task?",
98
+ }
99
+ )
100
+ assert isinstance(result, str)
101
+ assert "Excel file loaded" in result
102
+ print("\nAnalyze Excel Test Result:", result)
103
+
104
+
105
+ def test_transcribe_audio():
106
+ """Test audio transcription functionality."""
107
+ audio_path = "test_files/CECIL-I-NEED-YOU-CECIL.mp3"
108
+ result = transcribe_audio.invoke({"audio_file_path": audio_path})
109
+ assert isinstance(result, str)
110
+ print("\nTranscribe Audio Test Result:", result)
111
+
112
+
113
+ def test_transcribe_youtube():
114
+ """Test YouTube transcription functionality."""
115
+ url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" # Example video
116
+ result = transcribe_youtube.invoke({"url": url})
117
+ assert isinstance(result, str)
118
+ print("\nTranscribe YouTube Test Result:", result[:200])
119
+
120
+
121
+ def test_video_frames_to_images():
122
+ """Test video frame extraction functionality."""
123
+ url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" # Example video
124
+
125
+ with tempfile.TemporaryDirectory() as temp_dir:
126
+ result = video_frames_to_images.invoke(
127
+ {"url": url, "folder_name": temp_dir, "sample_interval_seconds": 5}
128
+ )
129
+ assert isinstance(result, list)
130
+ assert all(isinstance(path, str) for path in result)
131
+ assert os.path.exists(os.path.join(temp_dir, "frames"))
132
+ assert len(os.listdir(os.path.join(temp_dir, "frames"))) == len(result)
133
+ print(f"\nVideo Frames Test Result: Extracted {len(result)} frames")
134
+
135
+
136
+ def test_use_vision_model():
137
+ """Test vision model functionality."""
138
+ image_paths = ["test_files/people.jpeg", "test_files/text_in_image.jpg"]
139
+ result = use_vision_model.invoke(
140
+ {
141
+ "question": "What do you see in these images?",
142
+ "image_paths": image_paths,
143
+ }
144
+ )
145
+ assert isinstance(result, str)
146
+ print("\nVision Model Test Result:", result)
147
+
148
+
149
+ def test_review_youtube_video():
150
+ """Test YouTube video review functionality."""
151
+ url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" # Example video
152
+ question = "What is the main topic of this video?"
153
+ result = review_youtube_video.invoke({"url": url, "question": question})
154
+ assert isinstance(result, str)
155
+ print("\nReview YouTube Test Result:", result)