Spaces:
Runtime error
Runtime error
Upload 6 files
Browse filesAssignment final
app.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import requests
|
| 4 |
-
import inspect
|
| 5 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# (Keep Constants as is)
|
| 8 |
# --- Constants ---
|
|
@@ -15,9 +17,27 @@ class BasicAgent:
|
|
| 15 |
print("BasicAgent initialized.")
|
| 16 |
def __call__(self, question: str) -> str:
|
| 17 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 23 |
"""
|
|
@@ -40,7 +60,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 40 |
|
| 41 |
# 1. Instantiate Agent ( modify this part to create your agent)
|
| 42 |
try:
|
| 43 |
-
agent = BasicAgent()
|
| 44 |
except Exception as e:
|
| 45 |
print(f"Error instantiating agent: {e}")
|
| 46 |
return f"Error initializing agent: {e}", None
|
|
@@ -79,8 +99,9 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 79 |
if not task_id or question_text is None:
|
| 80 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 81 |
continue
|
| 82 |
-
try:
|
| 83 |
-
|
|
|
|
| 84 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 85 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
| 86 |
except Exception as e:
|
|
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import requests
|
|
|
|
| 4 |
import pandas as pd
|
| 5 |
+
from graph import graph
|
| 6 |
+
from langchain_core.messages import HumanMessage
|
| 7 |
+
from state import QuestionState
|
| 8 |
|
| 9 |
# (Keep Constants as is)
|
| 10 |
# --- Constants ---
|
|
|
|
| 17 |
print("BasicAgent initialized.")
|
| 18 |
def __call__(self, question: str) -> str:
|
| 19 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
| 20 |
+
intermediate = graph.invoke({"messages": [HumanMessage(content=question)]})
|
| 21 |
+
answer = intermediate['messages'][-1].content
|
| 22 |
+
print(f"Agent returning Gemini answer: {answer}")
|
| 23 |
+
return answer
|
| 24 |
+
class QuestionAgent:
    """Agent that answers GAIA questions via the LangGraph `graph`, passing
    along an attachment URL so the graph's tools can fetch task files.
    """

    def __init__(self):
        # Per-instance question counter. This replaces the previous
        # module-level `question_number` global, which leaked mutable state
        # into the module namespace; output is unchanged for the single
        # instance created by run_and_submit_all.
        self.question_number = 0
        print("QuestionAgent initialized.")

    def __call__(self, inquestion: str, inattachment_url: str) -> str:
        """Run one question through the graph.

        Args:
            inquestion: Question text.
            inattachment_url: URL where the task's attachment (if any) lives.

        Returns:
            The content of the graph's final message.
        """
        self.question_number += 1
        print(f"{self.question_number} Agent received question: {inquestion}...")
        state = QuestionState(
            question=inquestion,
            attachment_url=inattachment_url,
            messages=[HumanMessage(content=inquestion)],
        )
        intermediate = graph.invoke(state)
        answer = intermediate["messages"][-1].content
        print(f"Agent returning Gemini answer: {answer}")
        return answer
|
| 41 |
|
| 42 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 43 |
"""
|
|
|
|
| 60 |
|
| 61 |
# 1. Instantiate Agent ( modify this part to create your agent)
|
| 62 |
try:
|
| 63 |
+
agent = QuestionAgent() #BasicAgent()
|
| 64 |
except Exception as e:
|
| 65 |
print(f"Error instantiating agent: {e}")
|
| 66 |
return f"Error initializing agent: {e}", None
|
|
|
|
| 99 |
if not task_id or question_text is None:
|
| 100 |
print(f"Skipping item with missing task_id or question: {item}")
|
| 101 |
continue
|
| 102 |
+
try:
|
| 103 |
+
attachment_url = f"{api_url}/files/{task_id}"
|
| 104 |
+
submitted_answer = agent(question_text, attachment_url)
|
| 105 |
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
| 106 |
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
| 107 |
except Exception as e:
|
graph.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import google.generativeai as genai
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
| 5 |
+
from IPython.display import Image, display
|
| 6 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 7 |
+
from langgraph.graph import StateGraph, START, END
|
| 8 |
+
from langgraph.prebuilt import ToolNode
|
| 9 |
+
from langgraph.prebuilt import tools_condition
|
| 10 |
+
from pprint import pprint
|
| 11 |
+
from state import QuestionState
|
| 12 |
+
from tools import search_web, search_wikipedia, get_image_attachment, get_audio_attachment, get_youtube_transcript, get_excel_attachment, get_attachment
|
| 13 |
+
|
| 14 |
+
load_dotenv()

# Set up Gemini API key from environment variable; fail fast at import time
# so misconfiguration surfaces immediately rather than on the first API call.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Please set the GEMINI_API_KEY environment variable.")

# Reuse the already-validated key instead of reading the environment a
# second time (the old code called os.getenv again here).
genai.configure(api_key=GEMINI_API_KEY)

# Handy snippet for listing the models available to this key:
# models = genai.list_models()
# for m in models:
#     print(m.name)

# LLM
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-pro-preview-03-25",  # or "gemini-1.5-pro", "gemini-2.0-flash-001", etc.
    temperature=0.7,
    google_api_key=GEMINI_API_KEY,  # or set the GOOGLE_API_KEY environment variable
)
|
| 34 |
+
|
| 35 |
+
# TOOLS
|
| 36 |
+
tools = [search_web, search_wikipedia, get_image_attachment, get_audio_attachment, get_youtube_transcript, get_excel_attachment, get_attachment]
|
| 37 |
+
llm_with_tools = llm.bind_tools(tools)
|
| 38 |
+
|
| 39 |
+
sys_msg = SystemMessage(content="You are answering questions from the GAIA benchmark. You have access to tools to search the web and wikipedia to answer these questions. Answer with the shortest most concise response possible.You are a general AI assistant. I will ask you a question. Your answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. If the question refers to an attachment or media not included in the question, such as 'what is in this picture' get the attachment and use it to answer the question from this url {attachment_url}")
|
| 40 |
+
# GRAPH
|
| 41 |
+
|
| 42 |
+
# Node
def assistant(state: QuestionState):
    """LLM node: answer (or emit tool calls for) the conversation so far.

    Interpolates the task's attachment URL into the system prompt template,
    then invokes the tool-bound Gemini model.

    Returns:
        Partial state update with the model's response appended to messages.
    """
    prompt_text = sys_msg.content.format(attachment_url=state["attachment_url"])
    # Wrap in SystemMessage: a bare string in a chat-model input list is
    # coerced to a *human* message by LangChain, which would silently drop
    # the system role from these instructions.
    inputs = [SystemMessage(content=prompt_text)] + state["messages"]
    return {"messages": [llm_with_tools.invoke(inputs)]}
|
| 47 |
+
|
| 48 |
+
# Build graph
# Classic ReAct-style loop: the assistant node may emit tool calls, the tool
# node executes them, and control returns to the assistant until it produces
# a plain (non-tool-call) answer.
builder = StateGraph(QuestionState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))

builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
    # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
    tools_condition,
)
builder.add_edge("tools", "assistant")
# Compiled graph is the module's public entry point (imported by app.py).
graph = builder.compile()
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
    # Smoke test — only runs when this module is executed directly.
    # Previously this block ran at import time, so `from graph import graph`
    # in app.py issued a real Gemini API call (and an image download) on
    # every import.
    message = {
        "role": "user",
        "content": [
            {
                "type": "text",
                "text": "What is in the attached image?",
            }
        ],
    }
    testState = QuestionState(
        question="What is int the attached image?",
        attachment_url="https://img.freepik.com/premium-photo/cyberpunk-anime-robot-girl_1162089-424.jpg",
        messages=[message],  # HumanMessage(content="What is int the attached image?")
    )
    result = graph.invoke(testState)
    print(result['messages'][-1].content)
    # View
    # display(Image(graph.get_graph().draw_mermaid_png()))
|
| 81 |
+
'''
|
| 82 |
+
messages = [HumanMessage(content="Hello, what is 2 multiplied by 2?")]
|
| 83 |
+
messages = graph.invoke({"messages": messages})
|
| 84 |
+
for m in messages['messages']:
|
| 85 |
+
m.pretty_print()
|
| 86 |
+
|
| 87 |
+
messages = [AIMessage(content=f"So you said you were researching ocean mammals?", name="Model")]
|
| 88 |
+
messages.append(HumanMessage(content=f"Yes, that's right.",name="Lance"))
|
| 89 |
+
messages.append(AIMessage(content=f"Great, what would you like to learn about.", name="Model"))
|
| 90 |
+
messages.append(HumanMessage(content=f"I want to learn about the best place to see Orcas in the US.", name="Lance"))
|
| 91 |
+
|
| 92 |
+
for m in messages:
|
| 93 |
+
m.pretty_print()
|
| 94 |
+
|
| 95 |
+
result = llm.invoke(messages)
|
| 96 |
+
type(result)
|
| 97 |
+
result.pretty_print()
|
| 98 |
+
def gemini_llm(state):
|
| 99 |
+
prompt = state["question"]
|
| 100 |
+
model = genai.GenerativeModel("gemini-pro")
|
| 101 |
+
response = model.generate_content(prompt)
|
| 102 |
+
return {"question": prompt, "answer": response.text}
|
| 103 |
+
|
| 104 |
+
# Define the state schema
|
| 105 |
+
state_schema = {"question": str, "answer": str}
|
| 106 |
+
|
| 107 |
+
# Build the LangGraph graph
|
| 108 |
+
graph = StateGraph(state_schema)
|
| 109 |
+
graph.add_node("gemini", gemini_llm)
|
| 110 |
+
graph.set_entry_point("gemini")
|
| 111 |
+
graph.add_edge("gemini", END)
|
| 112 |
+
graph = graph.compile()
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
user_prompt = input("Enter your prompt: ")
|
| 116 |
+
result = graph.invoke({"question": user_prompt, "answer": ""})
|
| 117 |
+
print("Gemini response:", result["answer"])
|
| 118 |
+
'''
|
state.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import MessagesState
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
class QuestionState(MessagesState):
    # Graph state for one benchmark question. Extends MessagesState (which
    # supplies the accumulating `messages` list) with per-task fields.
    task_id: str          # task identifier from the scoring API — TODO confirm it is populated by callers
    question: str         # raw question text
    attachment_url: str   # URL of the task's attachment file, if any
    attachment_type: str  # MIME type of the attachment when known
|
| 9 |
+
|
tools.py
ADDED
|
@@ -0,0 +1,257 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
import io
import logging
import mimetypes
import re
import time

import requests
from langchain_community.document_loaders import WikipediaLoader, YoutubeLoader
from langchain_community.tools import TavilySearchResults, tool
from langchain_core.messages import SystemMessage
from PIL import Image
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound

from state import QuestionState
|
| 17 |
+
|
| 18 |
+
# --- Configure Logging (Optional but Recommended) ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Base URL of the scoring service that hosts question attachments.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# Search query writing
# System prompt used when asking the LLM to turn the conversation into a
# search query. (The f-string prefix has no placeholders — harmless no-op.)
search_instructions = SystemMessage(content=f"""Search the internet to find relevant answers to queries""")
|
| 26 |
+
|
| 27 |
+
def search_web(state: QuestionState):
    """ Retrieve docs from web search """
    # NOTE(review): `llm` and `SearchQuery` are not defined or imported in
    # this module — invoking this tool raises NameError. They presumably
    # live in graph.py, but importing from there would be circular
    # (graph.py imports this module); they belong in a shared module.
    # NOTE(review): returns a "context" key that QuestionState does not
    # declare — confirm the graph actually consumes it.

    logger.info("Tool called: search_web")
    # Search
    tavily_search = TavilySearchResults(max_results=3)

    # Search query
    structured_llm = llm.with_structured_output(SearchQuery)
    search_query = structured_llm.invoke([search_instructions]+state['messages'])

    # Search
    search_docs = tavily_search.invoke(search_query.search_query)

    # Format results as pseudo-XML documents for the LLM context window.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document href="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
            for doc in search_docs
        ]
    )

    return {"context": [formatted_search_docs]}
|
| 50 |
+
|
| 51 |
+
def search_wikipedia(state: QuestionState):
    """ Retrieve docs from wikipedia """
    # NOTE(review): `llm` and `SearchQuery` are undefined in this module, so
    # invoking this tool raises NameError (same issue as search_web); a
    # direct import from graph.py would be circular.
    # NOTE(review): the returned "context" key is not declared on
    # QuestionState — confirm the graph consumes it.

    logger.info("Tool called: search_wikipedia")
    # Search query
    structured_llm = llm.with_structured_output(SearchQuery)
    search_query = structured_llm.invoke([search_instructions]+state['messages'])

    # Search
    search_docs = WikipediaLoader(query=search_query.search_query,
                                  load_max_docs=2).load()

    # Format results as pseudo-XML documents for the LLM context window.
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )

    return {"context": [formatted_search_docs]}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_image_attachment(state: QuestionState):
    """ Retrieve image attachment for the current question """
    # Downloads the file at state["attachment_url"] and returns it as a
    # base64 data URL ("data:<mime>;base64,...") suitable for a multimodal
    # LLM, or None on failure. Requires the module-level `import base64`
    # (previously missing, which made this raise NameError at runtime).
    logger.info("Tool called: get_image_attachment")
    response = _download_with_retries(state["attachment_url"])
    if response is None:
        logger.error(f"Failed to download image after retries: {state['attachment_url']}")
        return None
    try:
        image_data = base64.b64encode(response.content).decode("utf-8")
    except Exception as e:
        logger.error(f"An error occurred while trying to process the image attachment: {e}")
        return None
    # Prefer the server-reported type, then a guess from the URL, then fall
    # back to JPEG. (The old code computed the URL guess twice redundantly.)
    content_type = (
        response.headers.get('content-type')
        or mimetypes.guess_type(state["attachment_url"])[0]
        or 'image/jpeg'
    )
    return f"data:{content_type};base64,{image_data}"
|
| 91 |
+
|
| 92 |
+
def get_audio_attachment(state: QuestionState):
    """ Retrieve audio attachment for the current question """
    # Returns the audio as a base64 data URL, or None on download failure.
    logger.info("Tool called: get_audio_attachment at " + state["attachment_url"])
    response = _download_with_retries(state["attachment_url"], stream=True)
    if response is None:
        logger.error(f"Failed to download audio after retries: {state['attachment_url']}")
        return None
    content_type = response.headers.get('content-type') or mimetypes.guess_type(state["attachment_url"])[0]
    # Fixed: the old log line built `{response.content-type}` — a set literal
    # subtracting the `type` builtin from the response bytes — which raised
    # TypeError on every successful download.
    logger.info(f"The Audio file ({content_type}) downloaded successfully")
    audio_data = base64.b64encode(response.content).decode("utf-8")
    return f"data:{content_type};base64,{audio_data}"
|
| 103 |
+
|
| 104 |
+
def get_excel_attachment(state: QuestionState):
    """ Retrieve excel attachment for the current question """
    # Returns (bytes, content_type) on success, (None, None) on failure.
    logger.info("Tool called: get_excel_attachment")
    resp = _download_with_retries(state["attachment_url"], stream=True)
    if resp is not None:
        return resp.content, resp.headers.get('Content-Type')
    logger.error(f"Failed to download excel after retries: {state['attachment_url']}")
    return None, None
|
| 113 |
+
|
| 114 |
+
def get_attachment(state: QuestionState):
    """ Retrieve attachment for the current question if a more specific attachment tool is not available"""
    # Generic fallback: returns (bytes, content_type), or (None, None) when
    # the download fails after all retries.
    logger.info("Tool called: get_attachment")
    resp = _download_with_retries(state["attachment_url"], stream=True)
    if resp is not None:
        return resp.content, resp.headers.get('Content-Type')
    logger.error(f"Failed to download attachment after retries: {state['attachment_url']}")
    return None, None
|
| 123 |
+
|
| 124 |
+
# --- Helper Function to Extract Video ID ---
|
| 125 |
+
|
| 126 |
+
def _download_with_retries(url, stream=False, retries=5, timeout=10):
    """Helper function to download a file with retries and logging.

    Args:
        url: The URL to fetch.
        stream: Passed through to requests.get.
        retries: Maximum number of attempts.
        timeout: Per-request timeout in seconds.

    Returns:
        The successful requests.Response, or None if every attempt failed.
    """
    for attempt in range(1, retries + 1):
        try:
            logger.info(f"Attempt {attempt} downloading: {url}")
            response = requests.get(url, stream=stream, timeout=timeout)
            response.raise_for_status()
            return response
        except Exception as e:
            logger.warning(f"Download failed (attempt {attempt}) for {url}: {e}")
            # Capped exponential backoff (1s, 2s, 4s, 8s) so transient
            # server errors can clear; the old code retried immediately.
            if attempt < retries:
                time.sleep(min(2 ** (attempt - 1), 8))
    logger.error(f"All {retries} attempts failed for download: {url}")
    return None
|
| 138 |
+
|
| 139 |
+
def extract_video_id(url: str) -> str | None:
|
| 140 |
+
"""Extracts the YouTube video ID from various URL formats."""
|
| 141 |
+
# Regex patterns to cover common YouTube URL formats
|
| 142 |
+
patterns = [
|
| 143 |
+
r'(?:https?:\/\/)?(?:www\.)?youtube\.com\/watch\?v=([a-zA-Z0-9_-]{11})', # Standard watch URL
|
| 144 |
+
r'(?:https?:\/\/)?(?:www\.)?youtu\.be\/([a-zA-Z0-9_-]{11})', # Shortened youtu.be URL
|
| 145 |
+
r'(?:https?:\/\/)?(?:www\.)?youtube\.com\/embed\/([a-zA-Z0-9_-]{11})', # Embed URL
|
| 146 |
+
r'(?:https?:\/\/)?(?:www\.)?youtube\.com\/v\/([a-zA-Z0-9_-]{11})', # V URL (older format)
|
| 147 |
+
r'([a-zA-Z0-9_-]{11})' # Attempt to match just an ID (less reliable)
|
| 148 |
+
]
|
| 149 |
+
for pattern in patterns:
|
| 150 |
+
match = re.search(pattern, url)
|
| 151 |
+
if match:
|
| 152 |
+
logger.info(f"Extracted video ID: {match.group(1)}")
|
| 153 |
+
return match.group(1)
|
| 154 |
+
logger.warning(f"Could not extract video ID from URL: {url}")
|
| 155 |
+
return None
|
| 156 |
+
|
| 157 |
+
# --- Direct Transcript Fetching Function ---
|
| 158 |
+
def get_youtube_transcript(youtube_url: str) -> str | None:
    """
    Retrieves the transcript for a YouTube video directly using youtube-transcript-api.

    Args:
        youtube_url: The URL of the YouTube video.

    Returns:
        The transcript as a single string, or None if an error occurs.
    """
    logger.info("Tool called: get_youtube_transcript")
    video_id = extract_video_id(youtube_url)
    if not video_id:
        logger.error("Invalid YouTube URL or could not extract Video ID.")
        return None # Return None for error, indicating failure

    try:
        logger.info(f"Fetching transcript for video ID: {video_id}")
        # Fetch the transcript (defaults to English, can specify languages)
        # NOTE(review): get_transcript is the classic static API; newer
        # youtube-transcript-api releases moved to an instance-based API —
        # confirm the pinned library version supports this call.
        transcript_list = YouTubeTranscriptApi.get_transcript(video_id)

        # Combine the transcript text parts into a single string
        transcript = " ".join([item['text'] for item in transcript_list])
        logger.info(f"Transcript fetched successfully (length: {len(transcript)} chars).")
        return transcript

    except TranscriptsDisabled:
        logger.error(f"Transcripts are disabled for video: {youtube_url}")
        return None
    except NoTranscriptFound:
        logger.error(f"No transcript found for video: {youtube_url}. Might be unavailable or in an unsupported language.")
        return None
    except Exception as e:
        # Catch any other unexpected errors (network, API changes, etc.)
        logger.error(f"An unexpected error occurred fetching transcript for {youtube_url}: {e}", exc_info=True)
        return None
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# test_url_with_transcript = "https://www.youtube.com/watch?v=dQw4w9WgXcQ" # Example (Rick Astley)
|
| 197 |
+
# test_url_no_transcript = "https://www.youtube.com/watch?v=some_video_without_transcripts" # Placeholder
|
| 198 |
+
# test_url_invalid = "htp:/invalid-url"
|
| 199 |
+
|
| 200 |
+
# print(f"\nTesting URL: {test_url_with_transcript}")
|
| 201 |
+
# transcript1 = get_youtube_transcript_direct(test_url_with_transcript)
|
| 202 |
+
# if transcript1:
|
| 203 |
+
# print("Transcript (first 500 chars):", transcript1[:500])
|
| 204 |
+
# else:
|
| 205 |
+
# print("Failed to get transcript.")
|
| 206 |
+
|
| 207 |
+
# print(f"\nTesting URL: {test_url_no_transcript}") # Uncomment to test known non-transcript video
|
| 208 |
+
# transcript2 = get_youtube_transcript_direct(test_url_no_transcript)
|
| 209 |
+
# if transcript2:
|
| 210 |
+
# print("Transcript:", transcript2[:500])
|
| 211 |
+
# else:
|
| 212 |
+
# print("Failed to get transcript.")
|
| 213 |
+
|
| 214 |
+
# print(f"\nTesting URL: {test_url_invalid}")
|
| 215 |
+
# transcript3 = get_youtube_transcript_direct(test_url_invalid)
|
| 216 |
+
# if transcript3:
|
| 217 |
+
# print("Transcript:", transcript3[:500])
|
| 218 |
+
# else:
|
| 219 |
+
# print("Failed to get transcript.")
|
| 220 |
+
|
| 221 |
+
"""
|
| 222 |
+
def get_audio_attachment(state: QuestionState):
|
| 223 |
+
response = requests.get(state["attachment_url"], stream=True)
|
| 224 |
+
response.raise_for_status()
|
| 225 |
+
audio_bytes = response.content
|
| 226 |
+
return audio_bytes, response.headers.get('Content-Type')
|
| 227 |
+
"""
|
| 228 |
+
# def load_attachment_for_llm(url):
|
| 229 |
+
# response = requests.get(url)
|
| 230 |
+
# content_type = response.headers.get('content-type') or mimetypes.guess_type(url)[0]
|
| 231 |
+
# if content_type:
|
| 232 |
+
# if content_type.startswith('image/'):
|
| 233 |
+
# return Image.open(io.BytesIO(response.content))
|
| 234 |
+
# elif content_type.startswith('audio/'):ou
|
| 235 |
+
# return io.BytesIO(response.content)
|
| 236 |
+
# elif content_type.startswith('text/'):
|
| 237 |
+
# return response.text
|
| 238 |
+
# # Add more handlers as needed (e.g., PDF, Excel)
|
| 239 |
+
# # Fallback: return bytes
|
| 240 |
+
# return io.BytesIO(response.content)
|
| 241 |
+
|
| 242 |
+
# def get_attachment(state: QuestionState):
|
| 243 |
+
# # """Retrieves and loads the attachment for the current question."""
|
| 244 |
+
# api_url = DEFAULT_API_URL
|
| 245 |
+
# attachment_url = f"{api_url}/files/{state.task_id}"
|
| 246 |
+
# # Store the URL in the state
|
| 247 |
+
# state.attachment_url = attachment_url
|
| 248 |
+
# # Load the attachment (image, audio, text, etc.)
|
| 249 |
+
# attachment = load_attachment_for_llm(attachment_url)
|
| 250 |
+
# # Store the loaded attachment in the state
|
| 251 |
+
# state.attachment = attachment
|
| 252 |
+
# # Return updated fields as a dict (LangGraph expects this)
|
| 253 |
+
# return {
|
| 254 |
+
# "attachment_url": attachment_url,
|
| 255 |
+
# "attachment": attachment
|
| 256 |
+
# }
|
| 257 |
+
# return {"attachment": io.BytesIO(response.content)}
|