agentsList.py ADDED
@@ -0,0 +1,340 @@
1
+ from typing import TypedDict, Optional
2
+ from langgraph.graph import StateGraph, START, END
3
+ from langchain_openai import ChatOpenAI
4
+ from langchain_core.messages import HumanMessage
5
+ from rich.console import Console
6
+ from smolagents import (
7
+ CodeAgent,
8
+ ToolCallingAgent,
9
+ OpenAIServerModel,
10
+ AgentLogger,
11
+ LogLevel,
12
+ Panel,
13
+ Text,
14
+ )
15
+ from tools import (
16
+ GetAttachmentTool,
17
+ GoogleSearchTool,
18
+ GoogleSiteSearchTool,
19
+ ContentRetrieverTool,
20
+ YoutubeVideoTool,
21
+ SpeechRecognitionTool,
22
+ ClassifierTool,
23
+ ImageToChessBoardFENTool,
24
+ chess_engine_locator,
25
+ )
26
+ import openai
27
+ import backoff
28
+
29
+
30
+
31
+ def create_genai_agent(verbosity: int = LogLevel.INFO):
32
+ get_attachment_tool = GetAttachmentTool()
33
+ speech_recognition_tool = SpeechRecognitionTool()
34
+ env_tools = [
35
+ get_attachment_tool,
36
+ ]
37
+ model = OpenAIServerModel(model_id="gpt-4.1")
38
+ console = Console(record=True)
39
+ logger = AgentLogger(level=verbosity, console=console)
40
+ steps_buffer = []
41
+
42
+
43
+ def capture_step_log(agent) -> None:
44
+ steps_buffer.append(console.export_text(clear=True))
45
+
46
+
47
+ agents = {
48
+ agent.name: agent
49
+ for agent in [
50
+ ToolCallingAgent(
51
+ name="general_assistant",
52
+ description="Answers questions for best of knowledge and common reasoning grounded on already known information. Can understand multimedia including audio and video files and YouTube.",
53
+ model=model,
54
+ tools=env_tools
55
+ + [
56
+ speech_recognition_tool,
57
+ YoutubeVideoTool(
58
+ client=model.client,
59
+ speech_recognition_tool=speech_recognition_tool,
60
+ frames_interval=3,
61
+ chunk_duration=60,
62
+ debug=True,
63
+ ),
64
+ ClassifierTool(
65
+ client=model.client,
66
+ model_id="gpt-4.1-mini",
67
+ ),
68
+ ],
69
+ logger=logger,
70
+ step_callbacks=[capture_step_log],
71
+ ),
72
+ ToolCallingAgent(
73
+ name="web_researcher",
74
+ description="Answers questions that require grounding in unknown information through search on web sites and other online resources.",
75
+ tools=env_tools
76
+ + [
77
+ GoogleSearchTool(),
78
+ GoogleSiteSearchTool(),
79
+ ContentRetrieverTool(),
80
+ ],
81
+ model=model,
82
+ planning_interval=3,
83
+ max_steps=9,
84
+ logger=logger,
85
+ step_callbacks=[capture_step_log],
86
+ ),
87
+ CodeAgent(
88
+ name="data_analyst",
89
+ description="Data analyst with advanced skills in statistic, handling tabular data and related Python packages.",
90
+ tools=env_tools,
91
+ additional_authorized_imports=[
92
+ "numpy",
93
+ "pandas",
94
+ "tabulate",
95
+ "matplotlib",
96
+ "seaborn",
97
+ ],
98
+ model=model,
99
+ logger=logger,
100
+ step_callbacks=[capture_step_log],
101
+ ),
102
+ CodeAgent(
103
+ name="chess_player",
104
+ description="Chess grandmaster empowered by chess engine. Always thinks at least 100 steps ahead.",
105
+ tools=env_tools
106
+ + [
107
+ ImageToChessBoardFENTool(client=model.client),
108
+ chess_engine_locator,
109
+ ],
110
+ additional_authorized_imports=[
111
+ "chess",
112
+ "chess.engine",
113
+ ],
114
+ model=model,
115
+ logger=logger,
116
+ step_callbacks=[capture_step_log],
117
+ ),
118
+ ]
119
+ }
120
+
121
+
122
+ class GAIATask(TypedDict, total=False):
123
+ task_id: Optional[str]
124
+ question: str
125
+ steps: list[str]
126
+ agent: Optional[str]
127
+ raw_answer: Optional[str]
128
+ final_answer: Optional[str]
129
+
130
+
131
+ llm = ChatOpenAI(model="gpt-4.1")
132
+ logger = AgentLogger(level=verbosity)
133
+
134
+
135
+ @backoff.on_exception(backoff.expo, openai.RateLimitError, max_time=60, max_tries=6)
136
+ def llm_invoke_with_retry(messages):
137
+ response = llm.invoke(messages)
138
+ return response
139
+
140
+
141
+ def read_question(state: GAIATask):
142
+ logger.log_task(
143
+ content=state["question"].strip(),
144
+ subtitle=f"LangGraph with {type(llm).__name__} - {llm.model_name}",
145
+ level=LogLevel.INFO,
146
+ title="Final Assignment Agent for Hugging Face Agents Course",
147
+ )
148
+ get_attachment_tool.attachment_for(state["task_id"])
149
+
150
+ return {
151
+ "steps": [],
152
+ "agent": None,
153
+ "raw_answer": None,
154
+ "final_answer": None,
155
+ }
156
+
157
+
158
+ def select_agent(state: GAIATask):
159
+ agents_description = "\n\n".join(
160
+ [
161
+ f"AGENT NAME: {a.name}\nAGENT DESCRIPTION: {a.description}"
162
+ for a in agents.values()
163
+ ]
164
+ )
165
+
166
+ prompt = f"""\
167
+ You are a general AI assistant.
168
+
169
+ I will provide you a question and a list of agents with their descriptions.
170
+ Your task is to select the most appropriate agent to answer the question.
171
+ You can select one of the agents or decide that no agent is needed.
172
+
173
+ If the question has an attachment, only an agent can answer it.
174
+
175
+ QUESTION:
176
+ {state["question"]}
177
+
178
+ {agents_description}
179
+
180
+ Now, return the name of the agent you selected or "no agent needed" if you think that no agent is needed.
181
+ """
182
+
183
+ response = llm_invoke_with_retry([HumanMessage(content=prompt)])
184
+ agent_name = response.content.strip()
185
+
186
+ if agent_name in agents:
187
+ logger.log(
188
+ f"Agent {agent_name} selected for solving the task.",
189
+ level=LogLevel.DEBUG,
190
+ )
191
+ return {
192
+ "agent": agent_name,
193
+ "steps": state.get("steps", [])
194
+ + [
195
+ f"Agent '{agent_name}' selected for task execution.",
196
+ ],
197
+ }
198
+ elif agent_name == "no agent needed":
199
+ logger.log(
200
+ "No appropriate agent found in the list. No agent will be used.",
201
+ level=LogLevel.DEBUG,
202
+ )
203
+ return {
204
+ "agent": None,
205
+ "steps": state.get("steps", [])
206
+ + [
207
+ "A decision is made to solve the task directly without invoking any agent.",
208
+ ],
209
+ }
210
+ else:
211
+ logger.log(
212
+ f"[bold red]Warning to user: Unexpected agent name '{agent_name}' selected. No agent will be used.[/bold red]",
213
+ level=LogLevel.INFO,
214
+ )
215
+ return {
216
+ "agent": None,
217
+ "steps": state.get("steps", [])
218
+ + [
219
+ f"Attempt to select non-existing agent '{agent_name}'. No agent will be used.",
220
+ ],
221
+ }
222
+
223
+
224
+ def delegate_to_agent(state: GAIATask):
225
+ agent_name = state.get("agent", None)
226
+ if not agent_name:
227
+ raise ValueError("Agent not selected.")
228
+ if agent_name not in agents:
229
+ raise ValueError(f"Agent '{agent_name}' is not available.")
230
+
231
+ logger.log(
232
+ Panel(Text(f"Calling agent: {agent_name}.")),
233
+ level=LogLevel.INFO,
234
+ )
235
+
236
+ agent = agents[agent_name]
237
+ agent_answer = agent.run(task=state["question"])
238
+ steps = [f"Agent '{agent_name}' step:\n{s}" for s in steps_buffer]
239
+ steps_buffer.clear()
240
+ return {
241
+ "raw_answer": agent_answer,
242
+ "steps": state.get("steps", []) + steps,
243
+ }
244
+
245
+
246
+ def one_shot_answering(state: GAIATask):
247
+ response = llm_invoke_with_retry([HumanMessage(content=state.get("question"))])
248
+ return {
249
+ "raw_answer": response.content,
250
+ "steps": state.get("steps", [])
251
+ + [
252
+ f"One-shot answer:\n{response.content}",
253
+ ],
254
+ }
255
+
256
+
257
+ def refine_answer(state: GAIATask):
258
+ question = state.get("question")
259
+ answer = state.get("raw_answer", None)
260
+ if not answer:
261
+ return {"final_answer": "No answer."}
262
+
263
+ prompt = f"""\
264
+ You are a general AI assistant.
265
+
266
+ I will provide you with a question and a correct answer to it. The answer is correct but may be too verbose or may not follow the rules below.
267
+ Your task is to rephrase the answer according to the rules below.
268
+
269
+ Answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
270
+
271
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
272
+ If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
273
+ If you are asked for a comma separated list, apply the above rules depending on whether each element of the list is a number or a string.
274
+
275
+ If you are asked for a comma separated list, put a space after each comma before the next element of the list unless directly specified otherwise in the question.
276
+ Check the question context to determine whether letter case matters. Do not change case unless prescribed by other rules or by the question.
277
+ If you are not asked for a list, capitalize the first letter of the answer unless that changes the meaning of the answer.
278
+ If the answer is a number, use digits only, not words, unless directly specified otherwise in the question.
279
+ If the answer is not a full sentence, do not add a period at the end.
280
+
281
+ Preserve all items if the answer is a list.
282
+
283
+ QUESTION:
284
+ {question}
285
+
286
+ ANSWER:
287
+ {answer}
288
+ """
289
+ response = llm_invoke_with_retry([HumanMessage(content=prompt)])
290
+ refined_answer = response.content.strip()
291
+ logger.log(
292
+ Text(f"GAIA final answer: {refined_answer}", style="bold #d4b702"),
293
+ level=LogLevel.INFO,
294
+ )
295
+ return {
296
+ "final_answer": refined_answer,
297
+ "steps": state.get("steps", [])
298
+ + [
299
+ "Refining the answer according to GAIA benchmark rules.",
300
+ f"FINAL ANSWER: {response.content}",
301
+ ],
302
+ }
303
+
304
+
305
+ def route_task(state: GAIATask) -> str:
306
+ if state.get("agent") in agents:
307
+ return "agent selected"
308
+ else:
309
+ return "no agent matched"
310
+
311
+
312
+ # Create the graph
313
+ gaia_graph = StateGraph(GAIATask)
314
+
315
+ # Add nodes
316
+ gaia_graph.add_node("read_question", read_question)
317
+ gaia_graph.add_node("select_agent", select_agent)
318
+ gaia_graph.add_node("delegate_to_agent", delegate_to_agent)
319
+ gaia_graph.add_node("one_shot_answering", one_shot_answering)
320
+ gaia_graph.add_node("refine_answer", refine_answer)
321
+
322
+ # Start the edges
323
+ gaia_graph.add_edge(START, "read_question")
324
+ # Add edges - defining the flow
325
+ gaia_graph.add_edge("read_question", "select_agent")
326
+
327
+ # Add conditional branching from select_agent
328
+ gaia_graph.add_conditional_edges(
329
+ "select_agent",
330
+ route_task,
331
+ {"agent selected": "delegate_to_agent", "no agent matched": "one_shot_answering"},
332
+ )
333
+
334
+ # Add the final edges
335
+ gaia_graph.add_edge("delegate_to_agent", "refine_answer")
336
+ gaia_graph.add_edge("one_shot_answering", "refine_answer")
337
+ gaia_graph.add_edge("refine_answer", END)
338
+
339
+ gaia = gaia_graph.compile()
340
+ return gaia
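
A minimal usage sketch of the compiled graph returned above (the task_id and question values are made up for illustration; OPENAI_API_KEY and the Google Search variables are assumed to be configured):

from agentsList import create_genai_agent

gaia = create_genai_agent()
# The graph reads the question, selects an agent (or none), delegates, then refines the answer.
result = gaia.invoke({"task_id": None, "question": "What is the capital of France?"})
print(result["final_answer"])
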
app.py CHANGED
@@ -3,21 +3,26 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
- # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
  class BasicAgent:
14
  def __init__(self):
 
15
  print("BasicAgent initialized.")
16
- def __call__(self, question: str) -> str:
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
@@ -34,6 +39,18 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
34
  print("User not logged in.")
35
  return "Please Login to Hugging Face with the button.", None
36
 
37
  api_url = DEFAULT_API_URL
38
  questions_url = f"{api_url}/questions"
39
  submit_url = f"{api_url}/submit"
@@ -80,7 +97,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
80
  print(f"Skipping item with missing task_id or question: {item}")
81
  continue
82
  try:
83
- submitted_answer = agent(question_text)
84
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
85
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
86
  except Exception as e:
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ import agentsList
7
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
10
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
 
12
  # --- Basic Agent Definition ---
 
13
  class BasicAgent:
14
  def __init__(self):
15
+ self.genaiAgent = agentsList.create_genai_agent()
16
  print("BasicAgent initialized.")
17
+ def __call__(self, task_id: str, question: str) -> str:
18
  print(f"Agent received question (first 50 chars): {question[:50]}...")
19
+ task = self.genaiAgent.invoke({
20
+ "task_id": task_id,
21
+ "question": question,
22
+ })
23
+ final_answer = task.get("final_answer")
24
+ print(f"Agent returning final answer: {final_answer}")
25
+ return final_answer
26
 
27
  def run_and_submit_all( profile: gr.OAuthProfile | None):
28
  """
 
39
  print("User not logged in.")
40
  return "Please Login to Hugging Face with the button.", None
41
 
42
+ # --- Allow only space owner to run agent to avoid misuse ---
43
+ if not space_id.startswith(username.strip()):
44
+ print("User is not an owner of the space. Please duplicate space and configure OPENAI_API_KEY, HF_TOKEN, GOOGLE_SEARCH_API_KEY, and GOOGLE_SEARCH_ENGINE_ID environment variables.")
45
+ return "Please duplicate space to your account to run the agent.", None
46
+
47
+ # --- Check for required environment variables ---
48
+ required_env_vars = ["OPENAI_API_KEY", "HF_TOKEN", "GOOGLE_SEARCH_API_KEY", "GOOGLE_SEARCH_ENGINE_ID"]
49
+ missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
50
+ if missing_env_vars:
51
+ print(f"Missing environment variables: {', '.join(missing_env_vars)}")
52
+ return f"Missing environment variables: {', '.join(missing_env_vars)}", None
53
+
54
  api_url = DEFAULT_API_URL
55
  questions_url = f"{api_url}/questions"
56
  submit_url = f"{api_url}/submit"
 
97
  print(f"Skipping item with missing task_id or question: {item}")
98
  continue
99
  try:
100
+ submitted_answer = agent(task_id=task_id, question=question_text)
101
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
102
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
103
  except Exception as e:
tools/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from .get_attachment_tool import GetAttachmentTool
2
+ from .google_search_tools import GoogleSearchTool, GoogleSiteSearchTool
3
+ from .content_retriever_tool import ContentRetrieverTool
4
+ from .speech_recognition_tool import SpeechRecognitionTool
5
+ from .youtube_video_tool import YoutubeVideoTool
6
+ from .classifier_tool import ClassifierTool
7
+ from .chess_tools import ImageToChessBoardFENTool, chess_engine_locator
8
+
9
+ __all__ = [
10
+ "GetAttachmentTool",
11
+ "GoogleSearchTool",
12
+ "GoogleSiteSearchTool",
13
+ "ContentRetrieverTool",
14
+ "SpeechRecognitionTool",
15
+ "YoutubeVideoTool",
16
+ "ClassifierTool",
17
+ "ImageToChessBoardFENTool",
18
+ "chess_engine_locator",
19
+ ]
tools/chess_tools.py ADDED
@@ -0,0 +1,126 @@
1
+ from smolagents import Tool, tool
2
+ from openai import OpenAI
3
+ import shutil
4
+
5
+
6
+ @tool
7
+ def chess_engine_locator() -> str | None:
8
+ """
9
+ Get the path to the chess engine binary. Can be used with chess.engine.SimpleEngine.popen_uci function from chess.engine Python module.
10
+ Returns:
11
+ str | None: Path to the chess engine, or None if no engine is found.
12
+ """
13
+ path = shutil.which("stockfish")
14
+ return path if path else None
15
+
16
+
17
+ class ImageToChessBoardFENTool(Tool):
18
+ name = "image_to_chess_board_fen"
19
+ description = """Convert a chessboard image to board part of the FEN."""
20
+ inputs = {
21
+ "image_url": {
22
+ "type": "string",
23
+ "description": "Public URL of the image (preferred) or base64 encoded image in data URL format.",
24
+ }
25
+ }
26
+ output_type = "string"
27
+
28
+ def __init__(self, client: OpenAI | None = None, **kwargs):
29
+ self.client = client if client is not None else OpenAI()
30
+ super().__init__(**kwargs)
31
+
32
+ def attachment_for(self, task_id: str | None):
33
+ self.task_id = task_id
34
+
35
+ def forward(self, image_url: str) -> str:
36
+ """
37
+ Convert a chessboard image to board part of the FEN.
38
+ Args:
39
+ image_url (str): Public URL of the image (preferred) or base64 encoded image in data URL format.
40
+ Returns:
41
+ str: Board part of the FEN.
42
+ """
43
+ client = self.client
44
+
45
+ response = client.responses.create(
46
+ model="gpt-4.1",
47
+ input=[
48
+ {
49
+ "role": "user",
50
+ "content": [
51
+ {
52
+ "type": "input_text",
53
+ "text": "Describe the position of the pieces on the chessboard from the image. Please, nothing else but description.",
54
+ },
55
+ {"type": "input_image", "image_url": image_url},
56
+ ],
57
+ }
58
+ ],
59
+ )
60
+
61
+ response = client.responses.create(
62
+ model="gpt-4.1",
63
+ input=[
64
+ {
65
+ "role": "user",
66
+ "content": [
67
+ {
68
+ "type": "input_text",
69
+ "text": "Describe the position of the pieces on the chessboard from the image. Please, nothing else but description.",
70
+ },
71
+ ],
72
+ }
73
+ ]
74
+ + response.output
75
+ + [
76
+ {
77
+ "role": "user",
78
+ "content": [
79
+ {
80
+ "type": "input_text",
81
+ "text": """\
82
+ Write down all positions with known pieces.
83
+ Use a standard one-letter code to name pieces.
84
+
85
+ It is important to use the correct case for piece code. Use upper case for white and lower case for black.
86
+ It is important to include information about all the mentioned positions.
87
+
88
+ Describe each position in a new line.
89
+ Follow format: <piece><position> (piece first, then position, no spaces)
90
+ Return nothing but lines with positions.
91
+ """,
92
+ },
93
+ ],
94
+ }
95
+ ],
96
+ )
97
+ board_pos = response.output_text
98
+
99
+ pos_dict = {}
100
+ for pos_str in board_pos.splitlines():
101
+ pos_str = pos_str.strip()
102
+ if len(pos_str) != 3:
103
+ continue
104
+ piece = pos_str[0]
105
+ pos = pos_str[1:3]
106
+ pos_dict[pos] = piece
107
+
108
+ board_fen = ""
109
+ for rank in range(8, 0, -1):
110
+ empty = 0
111
+ for file_c in range(ord("a"), ord("h") + 1):
112
+ file = chr(file_c)
113
+ square = file + str(rank)
114
+ if square in pos_dict:
115
+ if empty > 0:
116
+ board_fen += str(empty)
117
+ empty = 0
118
+ board_fen += pos_dict[square]
119
+ else:
120
+ empty += 1
121
+ if empty > 0:
122
+ board_fen += str(empty)
123
+ if rank != 1:
124
+ board_fen += "/"
125
+
126
+ return board_fen
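
A rough sketch of how these two tools could be combined with python-chess (the image URL and the side-to-move/castling fields appended to the board FEN are illustrative assumptions; Stockfish must be installed for the locator to return a path, and OPENAI_API_KEY must be set):

import chess
import chess.engine

fen_tool = ImageToChessBoardFENTool()
board_fen = fen_tool.forward(image_url="https://example.com/board.png")  # hypothetical URL
engine_path = chess_engine_locator()  # the @tool object is callable directly
if engine_path:
    board = chess.Board(f"{board_fen} b - - 0 1")  # assume black to move, no castling rights
    with chess.engine.SimpleEngine.popen_uci(engine_path) as engine:
        print(engine.play(board, chess.engine.Limit(time=1.0)).move)
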
tools/classifier_tool.py ADDED
@@ -0,0 +1,89 @@
1
+ from smolagents import Tool
2
+ from openai import OpenAI
3
+
4
+
5
+ class ClassifierTool(Tool):
6
+ name = "open_classifier"
7
+ description = """Classifies given items into given categories from perspective of specific knowledge area."""
8
+ inputs = {
9
+ "knowledge_area": {
10
+ "type": "string",
11
+ "description": "The knowledge area that should be used for classification.",
12
+ },
13
+ "environment": { # context make models too verbose
14
+ "type": "string",
15
+ "description": "Couple words that describe environment or location in which items should be classified in case of plural meaning or if only part of item relevant for classification.",
16
+ },
17
+ "categories": {
18
+ "type": "string",
19
+ "description": "Comma separated list of categories to distribute objects.",
20
+ },
21
+ "items": {
22
+ "type": "string",
23
+ "description": "Comma separated list of items to be classified. Please include adjectives if available.",
24
+ },
25
+ }
26
+ output_type = "string"
27
+
28
+ def __init__(
29
+ self,
30
+ client: OpenAI | None = None,
31
+ model_id: str = "gpt-4.1-mini",
32
+ **kwargs,
33
+ ):
34
+ self.client = client or OpenAI()
35
+ self.model_id = model_id
36
+
37
+ super().__init__(**kwargs)
38
+
39
+ def forward(
40
+ self, knowledge_area: str, environment: str, categories: str, items: str
41
+ ) -> str:
42
+ response = self.client.responses.create(
43
+ model=self.model_id,
44
+ input=[
45
+ {
46
+ "role": "user",
47
+ "content": [
48
+ {
49
+ "type": "input_text",
50
+ "text": self._prompt(
51
+ knowledge_area=knowledge_area,
52
+ context=environment,
53
+ categories=categories,
54
+ items=items,
55
+ ),
56
+ },
57
+ ],
58
+ }
59
+ ],
60
+ )
61
+ answer = response.output_text
62
+ return answer
63
+
64
+ def _prompt(
65
+ self, knowledge_area: str, context: str, categories: str, items: str
66
+ ) -> str:
67
+ return f"""\
68
+ You are a {knowledge_area} classifier operating in the {context} context.
69
+ I will provide you with a list of items, a list of categories, and the context in which the items should be considered.
70
+
71
+ Your task is to classify the items into the categories.
72
+ Use the context to determine the meaning of the items and to decide whether you need to classify the entire item or only part of it.
73
+
74
+ Do not miss any item and do not add any item that was not in the list.
75
+ Use the highest-probability category for each item.
76
+ You can add the category "Other" if you are not sure about the classification.
77
+
78
+ Use only considerations from the {knowledge_area} perspective.
79
+ Explain your reasoning from the {knowledge_area} perspective in the {context} context and then provide the final answer.
80
+ Important: Do not allow {context} to influence your judgment for classification.
81
+
82
+ ITEMS: {items}
83
+ CATEGORIES: {categories}
84
+
85
+ Now provide your reasoning and finalize it with the classification in the following format:
86
+ Category 1: items list
87
+ Category 2: items list
88
+ Other (if needed): items list
89
+ """
tools/content_retriever_tool.py ADDED
@@ -0,0 +1,89 @@
1
+ from smolagents import Tool
2
+ from docling.document_converter import DocumentConverter
3
+ from docling.chunking import HierarchicalChunker
4
+ from sentence_transformers import SentenceTransformer, util
5
+ import torch
6
+
7
+
8
+ class ContentRetrieverTool(Tool):
9
+ name = "retrieve_content"
10
+ description = """Retrieve the content of a webpage or document in markdown format. Supports PDF, DOCX, XLSX, HTML, images, and more."""
11
+ inputs = {
12
+ "url": {
13
+ "type": "string",
14
+ "description": "The URL or local path of the webpage or document to retrieve.",
15
+ },
16
+ "query": {
17
+ "type": "string",
18
+ "description": "The subject on the page you are looking for. The shorter the more relevant content is returned.",
19
+ },
20
+ }
21
+ output_type = "string"
22
+
23
+ def __init__(
24
+ self,
25
+ model_name: str | None = None,
26
+ threshold: float = 0.2,
27
+ **kwargs,
28
+ ):
29
+ self.threshold = threshold
30
+ self._document_converter = DocumentConverter()
31
+ self._model = SentenceTransformer(
32
+ model_name if model_name is not None else "all-MiniLM-L6-v2"
33
+ )
34
+ self._chunker = HierarchicalChunker()
35
+
36
+ super().__init__(**kwargs)
37
+
38
+ def forward(self, url: str, query: str) -> str:
39
+ document = self._document_converter.convert(url).document
40
+
41
+ chunks = list(self._chunker.chunk(dl_doc=document))
42
+ if len(chunks) == 0:
43
+ return "No content found."
44
+
45
+ chunks_text = [chunk.text for chunk in chunks]
46
+ chunks_with_context = [self._chunker.contextualize(chunk) for chunk in chunks]
47
+ chunks_context = [
48
+ chunks_with_context[i].replace(chunks_text[i], "").strip()
49
+ for i in range(len(chunks))
50
+ ]
51
+
52
+ chunk_embeddings = self._model.encode(chunks_text, convert_to_tensor=True)
53
+ context_embeddings = self._model.encode(chunks_context, convert_to_tensor=True)
54
+ query_embedding = self._model.encode(
55
+ [term.strip() for term in query.split(",") if term.strip()],
56
+ convert_to_tensor=True,
57
+ )
58
+
59
+ selected_indices = []  # aggregate indices across chunk and context matches, for all query terms
60
+ for embeddings in [
61
+ context_embeddings,
62
+ chunk_embeddings,
63
+ ]:
64
+ # Compute cosine similarities (returns 1D tensor)
65
+ for cos_scores in util.pytorch_cos_sim(query_embedding, embeddings):
66
+ # Convert to softmax probabilities
67
+ probabilities = torch.nn.functional.softmax(cos_scores, dim=0)
68
+ # Sort by probability descending
69
+ sorted_indices = torch.argsort(probabilities, descending=True)
70
+ # Accumulate until total probability reaches threshold
71
+
72
+ cumulative = 0.0
73
+ for i in sorted_indices:
74
+ cumulative += probabilities[i].item()
75
+ selected_indices.append(i.item())
76
+ if cumulative >= self.threshold:
77
+ break
78
+
79
+ selected_indices = list(
80
+ dict.fromkeys(selected_indices)
81
+ ) # remove duplicates and preserve order
82
+ selected_indices = selected_indices[
83
+ ::-1
84
+ ] # make most relevant items last for better focus
85
+
86
+ if len(selected_indices) == 0:
87
+ return "No content found."
88
+
89
+ return "\n\n".join([chunks_with_context[idx] for idx in selected_indices])
tools/get_attachment_tool.py ADDED
@@ -0,0 +1,77 @@
1
+ from smolagents import Tool
2
+ import requests
3
+ from urllib.parse import urljoin
4
+ import base64
5
+ import tempfile
6
+
7
+
8
+ class GetAttachmentTool(Tool):
9
+ name = "get_attachment"
10
+ description = """Retrieves attachment for current task in specified format."""
11
+ inputs = {
12
+ "fmt": {
13
+ "type": "string",
14
+ "description": "Format to retrieve attachment. Options are: URL (preferred), DATA_URL, LOCAL_FILE_PATH, TEXT. URL returns the URL of the file, DATA_URL returns a base64 encoded data URL, LOCAL_FILE_PATH returns a local file path to the downloaded file, and TEXT returns the content of the file as text.",
15
+ "nullable": True,
16
+ "default": "URL",
17
+ }
18
+ }
19
+ output_type = "string"
20
+
21
+ def __init__(
22
+ self,
23
+ agent_evaluation_api: str | None = None,
24
+ task_id: str | None = None,
25
+ **kwargs,
26
+ ):
27
+ self.agent_evaluation_api = (
28
+ agent_evaluation_api
29
+ if agent_evaluation_api is not None
30
+ else "https://agents-course-unit4-scoring.hf.space/"
31
+ )
32
+ self.task_id = task_id
33
+ super().__init__(**kwargs)
34
+
35
+ def attachment_for(self, task_id: str | None):
36
+ self.task_id = task_id
37
+
38
+ def forward(self, fmt: str = "URL") -> str:
39
+ fmt = fmt.upper()
40
+ assert fmt in ["URL", "DATA_URL", "LOCAL_FILE_PATH", "TEXT"]
41
+
42
+ if not self.task_id:
43
+ return ""
44
+
45
+ file_url = urljoin(self.agent_evaluation_api, f"files/{self.task_id}")
46
+ if fmt == "URL":
47
+ return file_url
48
+
49
+ response = requests.get(
50
+ file_url,
51
+ headers={
52
+ "Content-Type": "application/json",
53
+ "Accept": "application/json",
54
+ },
55
+ )
56
+ if 400 <= response.status_code < 500:
57
+ return ""
58
+
59
+ response.raise_for_status()
60
+ mime = response.headers.get("content-type", "text/plain")
61
+ if fmt == "TEXT":
62
+ if mime.startswith("text/"):
63
+ return response.text
64
+ else:
65
+ raise ValueError(
66
+ f"Content of file type {mime} cannot be retrieved as TEXT."
67
+ )
68
+ elif fmt == "DATA_URL":
69
+ return f"data:{mime};base64,{base64.b64encode(response.content).decode('utf-8')}"
70
+ elif fmt == "LOCAL_FILE_PATH":
71
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
72
+ tmp_file.write(response.content)
73
+ return tmp_file.name
74
+ else:
75
+ raise ValueError(
76
+ f"Unsupported format: {fmt}. Supported formats are URL, DATA_URL, LOCAL_FILE_PATH, and TEXT."
77
+ )
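
A small usage sketch (the task id is a placeholder; the scoring API only serves files for real GAIA task ids):

attachments = GetAttachmentTool()
attachments.attachment_for("some-task-id")               # hypothetical id
print(attachments.forward(fmt="URL"))                    # remote URL of the attachment
local_path = attachments.forward(fmt="LOCAL_FILE_PATH")  # downloads to a temp file
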
tools/google_search_tools.py ADDED
@@ -0,0 +1,90 @@
1
+ from smolagents import Tool
2
+ from googleapiclient.discovery import build
3
+ import os
4
+
5
+
6
+ class GoogleSearchTool(Tool):
7
+ name = "web_search"
8
+ description = """Performs a google web search for query then returns top search results in markdown format."""
9
+ inputs = {
10
+ "query": {
11
+ "type": "string",
12
+ "description": "The query to perform search.",
13
+ },
14
+ }
15
+ output_type = "string"
16
+
17
+ skip_forward_signature_validation = True
18
+
19
+ def __init__(
20
+ self,
21
+ api_key: str | None = None,
22
+ search_engine_id: str | None = None,
23
+ num_results: int = 10,
24
+ **kwargs,
25
+ ):
26
+ api_key = api_key if api_key is not None else os.getenv("GOOGLE_SEARCH_API_KEY")
27
+ if not api_key:
28
+ raise ValueError(
29
+ "Please set the GOOGLE_SEARCH_API_KEY environment variable."
30
+ )
31
+ search_engine_id = (
32
+ search_engine_id
33
+ if search_engine_id is not None
34
+ else os.getenv("GOOGLE_SEARCH_ENGINE_ID")
35
+ )
36
+ if not search_engine_id:
37
+ raise ValueError(
38
+ "Please set the GOOGLE_SEARCH_ENGINE_ID environment variable."
39
+ )
40
+
41
+ self.cse = build("customsearch", "v1", developerKey=api_key).cse()
42
+ self.cx = search_engine_id
43
+ self.num = num_results
44
+ super().__init__(**kwargs)
45
+
46
+ def _collect_params(self) -> dict:
47
+ return {}
48
+
49
+ def forward(self, query: str, *args, **kwargs) -> str:
50
+ params = {
51
+ "q": query,
52
+ "cx": self.cx,
53
+ "fields": "items(title,link,snippet)",
54
+ "num": self.num,
55
+ }
56
+
57
+ params = params | self._collect_params(*args, **kwargs)
58
+
59
+ response = self.cse.list(**params).execute()
60
+ if "items" not in response:
61
+ return "No results found."
62
+
63
+ result = "\n\n".join(
64
+ [
65
+ f"[{item['title']}]({item['link']})\n{item['snippet']}"
66
+ for item in response["items"]
67
+ ]
68
+ )
69
+ return result
70
+
71
+
72
+ class GoogleSiteSearchTool(GoogleSearchTool):
73
+ name = "site_search"
74
+ description = """Performs a google search within the website for query then returns top search results in markdown format."""
75
+ inputs = {
76
+ "query": {
77
+ "type": "string",
78
+ "description": "The query to perform search.",
79
+ },
80
+ "site": {
81
+ "type": "string",
82
+ "description": "The domain of the site on which to search.",
83
+ },
84
+ }
85
+
86
+ def _collect_params(self, site: str) -> dict:
87
+ return {
88
+ "siteSearch": site,
89
+ "siteSearchFilter": "i",
90
+ }
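
A quick usage sketch (assumes GOOGLE_SEARCH_API_KEY and GOOGLE_SEARCH_ENGINE_ID are set; the query and site are examples only):

web = GoogleSearchTool(num_results=5)
print(web.forward("GAIA benchmark agents"))

site = GoogleSiteSearchTool(num_results=5)
print(site.forward("inference endpoints", site="huggingface.co"))
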
tools/speech_recognition_tool.py ADDED
@@ -0,0 +1,113 @@
1
+ from smolagents import Tool
2
+ import torch
3
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, logging
4
+ import warnings
5
+
6
+
7
+ class SpeechRecognitionTool(Tool):
8
+ name = "speech_to_text"
9
+ description = """Transcribes speech from audio."""
10
+
11
+ inputs = {
12
+ "audio": {
13
+ "type": "string",
14
+ "description": "Path to the audio file to transcribe.",
15
+ },
16
+ "with_time_markers": {
17
+ "type": "boolean",
18
+ "description": "Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float, float], indicating the number of seconds elapsed from the start of the audio.",
19
+ "nullable": True,
20
+ "default": False,
21
+ },
22
+ }
23
+ output_type = "string"
24
+
25
+ chunk_length_s = 30
26
+
27
+ def __new__(cls, *args, **kwargs):
28
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
29
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
30
+
31
+ model_id = "openai/whisper-large-v3-turbo"
32
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
33
+ model_id,
34
+ torch_dtype=torch_dtype,
35
+ low_cpu_mem_usage=True,
36
+ use_safetensors=True,
37
+ )
38
+ model.to(device)
39
+ processor = AutoProcessor.from_pretrained(model_id)
40
+
41
+ logging.set_verbosity_error()
42
+ warnings.filterwarnings(
43
+ "ignore",
44
+ category=FutureWarning,
45
+ message=r".*The input name `inputs` is deprecated.*",
46
+ )
47
+ cls.pipe = pipeline(
48
+ "automatic-speech-recognition",
49
+ model=model,
50
+ tokenizer=processor.tokenizer,
51
+ feature_extractor=processor.feature_extractor,
52
+ torch_dtype=torch_dtype,
53
+ device=device,
54
+ chunk_length_s=cls.chunk_length_s,
55
+ return_timestamps=True,
56
+ )
57
+
58
+ return super().__new__(cls)  # object.__new__() does not accept extra arguments
59
+
60
+ def forward(self, audio: str, with_time_markers: bool = False) -> str:
61
+ """
62
+ Transcribes speech from audio.
63
+
64
+ Args:
65
+ audio (str): Path to the audio file to transcribe.
66
+ with_time_markers (bool): Whether to include timestamps in the transcription output. Each timestamp appears on its own line in the format [float], indicating the number of seconds elapsed from the start of the audio.
67
+
68
+ Returns:
69
+ str: The transcribed text.
70
+ """
71
+ result = self.pipe(audio)
72
+ if not with_time_markers:
73
+ return result["text"].strip()
74
+
75
+ txt = ""
76
+ for chunk in self._normalize_chunks(result["chunks"]):
77
+ txt += f"[{chunk['start']:.2f}]\n{chunk['text']}\n[{chunk['end']:.2f}]\n"
78
+ return txt.strip()
79
+
80
+ def transcribe(self, audio, **kwargs):
81
+ result = self.pipe(audio, **kwargs)
82
+ return self._normalize_chunks(result["chunks"])
83
+
84
+ def _normalize_chunks(self, chunks):
85
+ chunk_length_s = self.chunk_length_s
86
+ absolute_offset = 0.0
87
+ chunk_offset = 0.0
88
+ normalized = []
89
+
90
+ for chunk in chunks:
91
+ timestamp_start = chunk["timestamp"][0]
92
+ timestamp_end = chunk["timestamp"][1]
93
+ if timestamp_start < chunk_offset:
94
+ absolute_offset += chunk_length_s
95
+ chunk_offset = timestamp_start
96
+ absolute_start = absolute_offset + timestamp_start
97
+
98
+ if timestamp_end < timestamp_start:
99
+ absolute_offset += chunk_length_s
100
+ absolute_end = absolute_offset + timestamp_end
101
+ chunk_offset = timestamp_end
102
+
103
+ chunk_text = chunk["text"].strip()
104
+ if chunk_text:
105
+ normalized.append(
106
+ {
107
+ "start": absolute_start,
108
+ "end": absolute_end,
109
+ "text": chunk_text,
110
+ }
111
+ )
112
+
113
+ return normalized
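
Whisper's pipeline restarts timestamps inside every 30-second window, so _normalize_chunks adds chunk_length_s whenever a timestamp wraps backwards: a segment reported as (2.0, 10.0) after one that ended at 29.5 becomes roughly 32.0 to 40.0 seconds absolute. A minimal usage sketch (the audio path is a placeholder; the first call downloads openai/whisper-large-v3-turbo):

speech = SpeechRecognitionTool()
print(speech.forward("interview.mp3", with_time_markers=True))  # hypothetical local file
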
tools/youtube_video_tool.py ADDED
@@ -0,0 +1,383 @@
1
+ from smolagents import Tool
2
+ from openai import OpenAI
3
+ from .speech_recognition_tool import SpeechRecognitionTool
4
+ from io import BytesIO
5
+ import yt_dlp
6
+ import av
7
+ import torchaudio
8
+ import subprocess
9
+ import requests
10
+ import base64
11
+
12
+
13
+ class YoutubeVideoTool(Tool):
14
+ name = "youtube_video"
15
+ description = """Process the video and return the requested information from it."""
16
+ inputs = {
17
+ "url": {
18
+ "type": "string",
19
+ "description": "The URL of the YouTube video.",
20
+ },
21
+ "query": {
22
+ "type": "string",
23
+ "description": "The question to answer.",
24
+ },
25
+ }
26
+ output_type = "string"
27
+
28
+ def __init__(
29
+ self,
30
+ video_quality: int = 360,
31
+ frames_interval: int | float | None = 2,
32
+ chunk_duration: int | float | None = 20,
33
+ speech_recognition_tool: SpeechRecognitionTool | None = None,
34
+ client: OpenAI | None = None,
35
+ model_id: str = "gpt-4.1-mini",
36
+ debug: bool = False,
37
+ **kwargs,
38
+ ):
39
+ self.video_quality = video_quality
40
+ self.speech_recognition_tool = speech_recognition_tool
41
+ self.frames_interval = frames_interval
42
+ self.chunk_duration = chunk_duration
43
+
44
+ self.client = client or OpenAI()
45
+ self.model_id = model_id
46
+
47
+ self.debug = debug
48
+
49
+ super().__init__(**kwargs)
50
+
51
+ def forward(self, url: str, query: str):
52
+ """
53
+ Process the video and return the requested information.
54
+ Args:
55
+ url (str): The URL of the YouTube video.
56
+ query (str): The question to answer.
57
+ Returns:
58
+ str: Answer to the query.
59
+ """
60
+ answer = ""
61
+ for chunk in self._split_video_into_chunks(url):
62
+ prompt = self._prompt(
63
+ chunk,
64
+ query,
65
+ answer,
66
+ )
67
+ response = self.client.responses.create(
68
+ model="gpt-4.1-mini",
69
+ input=[
70
+ {
71
+ "role": "user",
72
+ "content": [
73
+ {
74
+ "type": "input_text",
75
+ "text": prompt,
76
+ },
77
+ *[
78
+ {
79
+ "type": "input_image",
80
+ "image_url": f"data:image/jpeg;base64,{frame}",
81
+ }
82
+ for frame in self._base64_frames(chunk["frames"])
83
+ ],
84
+ ],
85
+ }
86
+ ],
87
+ )
88
+ answer = response.output_text
89
+ if self.debug:
90
+ print(
91
+ f"CHUNK {chunk['start']} - {chunk['end']}:\n\n{prompt}\n\nANSWER:\n{answer}"
92
+ )
93
+
94
+ if answer.strip() == "I need to keep watching":
95
+ answer = ""
96
+ return answer
97
+
98
+ def _prompt(self, chunk, query, aggregated_answer):
99
+ prompt = [
100
+ f"""\
101
+ These are some frames of a video that I want to upload.
102
+ I will ask a question about the entire video, but I will only show the last part of it.
103
+ Aggregate an answer about the entire video: use information about previous parts, but do not reference the previous parts directly in the answer.
104
+
105
+ Ground your answer in the video title, description, captions, video frames, or the answer from previous parts.
106
+ If no evidence is presented, just say "I need to keep watching".
107
+
108
+ VIDEO TITLE:
109
+ {chunk["title"]}
110
+
111
+ VIDEO DESCRIPTION:
112
+ {chunk["description"]}
113
+
114
+ FRAMES SUBTITLES:
115
+ {chunk["captions"]}"""
116
+ ]
117
+
118
+ if aggregated_answer:
119
+ prompt.append(f"""\
120
+ Here is the answer to the same question based on the previous video parts:
121
+
122
+ BASED ON PREVIOUS PARTS:
123
+ {aggregated_answer}""")
124
+
125
+ prompt.append(f"""\
126
+
127
+ QUESTION:
128
+ {query}""")
129
+
130
+ return "\n\n".join(prompt)
131
+
132
+ def _split_video_into_chunks(
133
+ self, url: str, with_captions: bool = True, with_frames: bool = True
134
+ ):
135
+ video = self._process_video(
136
+ url, with_captions=with_captions, with_frames=with_frames
137
+ )
138
+ video_duration = video["duration"]
139
+ chunk_duration = self.chunk_duration or video_duration
140
+
141
+ chunk_start = 0.0
142
+ while chunk_start < video_duration:
143
+ chunk_end = min(chunk_start + chunk_duration, video_duration)
144
+ chunk = self._get_video_chunk(video, chunk_start, chunk_end)
145
+ yield chunk
146
+ chunk_start += chunk_duration
147
+
148
+ def _get_video_chunk(self, video, start, end):
149
+ chunk_captions = [
150
+ c for c in video["captions"] if c["start"] <= end and c["end"] >= start
151
+ ]
152
+ chunk_frames = [
153
+ f
154
+ for f in video["frames"]
155
+ if f["timestamp"] >= start and f["timestamp"] <= end
156
+ ]
157
+
158
+ return {
159
+ "title": video["title"],
160
+ "description": video["description"],
161
+ "start": start,
162
+ "end": end,
163
+ "captions": "\n".join([c["text"] for c in chunk_captions]),
164
+ "frames": chunk_frames,
165
+ }
166
+
167
+ def _process_video(
168
+ self, url: str, with_captions: bool = True, with_frames: bool = True
169
+ ):
170
+ lang = "en"
171
+ info = self._get_video_info(url, lang)
172
+
173
+ if with_captions:
174
+ captions = self._extract_captions(
175
+ lang, info.get("subtitles", {}), info.get("automatic_captions", {})
176
+ )
177
+ if not captions and self.speech_recognition_tool:
178
+ audio_url = self._select_audio_format(info["formats"])
179
+ audio = self._capture_audio(audio_url)
180
+ waveform, sample_rate = torchaudio.load(audio)
181
+ assert sample_rate == 16000
182
+ waveform_np = waveform.squeeze().numpy()
183
+ captions = self.speech_recognition_tool.transcribe(waveform_np)
184
+ else:
185
+ captions = []
186
+
187
+ if with_frames:
188
+ video_url = self._select_video_format(info["formats"], self.video_quality)["url"]
189
+ frames = self._capture_video_frames(video_url, self.frames_interval)
190
+ else:
191
+ frames = []
192
+
193
+ return {
194
+ "id": info["id"],
195
+ "title": info["title"],
196
+ "description": info["description"],
197
+ "duration": info["duration"],
198
+ "captions": captions,
199
+ "frames": frames,
200
+ }
201
+
202
+ def _get_video_info(self, url: str, lang: str):
203
+ ydl_opts = {
204
+ "quiet": True,
205
+ "skip_download": True,
206
+ "format": "bestvideo[ext=mp4][height<=360]+bestaudio[ext=m4a]/best[height<=360]",
207
+ "forceurl": True,
208
+ "noplaylist": True,
209
+ "writesubtitles": True,
210
+ "writeautomaticsub": True,
211
+ "subtitlesformat": "vtt",
212
+ "subtitleslangs": [lang],
213
+ }
214
+
215
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
216
+ info = ydl.extract_info(url, download=False)
217
+
218
+ return info
219
+
220
+ def _extract_captions(self, lang, subtitles, auto_captions):
221
+ caption_tracks = subtitles.get(lang) or auto_captions.get(lang) or []
222
+
223
+ structured_captions = []
224
+
225
+ srt_track = next(
226
+ (track for track in caption_tracks if track["ext"] == "srt"), None
227
+ )
228
+ vtt_track = next(
229
+ (track for track in caption_tracks if track["ext"] == "vtt"), None
230
+ )
231
+
232
+ if srt_track:
233
+ import pysrt
234
+
235
+ response = requests.get(srt_track["url"])
236
+ response.raise_for_status()
237
+ srt_data = response.content.decode("utf-8")
238
+
239
+ def to_sec(t):
240
+ return (
241
+ t.hours * 3600 + t.minutes * 60 + t.seconds + t.milliseconds / 1000
242
+ )
243
+
244
+ structured_captions = [
245
+ {
246
+ "start": to_sec(sub.start),
247
+ "end": to_sec(sub.end),
248
+ "text": sub.text.strip(),
249
+ }
250
+ for sub in pysrt.from_str(srt_data)
251
+ ]
252
+ if vtt_track:
253
+ import webvtt
254
+ from io import StringIO
255
+
256
+ response = requests.get(vtt_track["url"])
257
+ response.raise_for_status()
258
+ vtt_data = response.text
259
+
260
+ vtt_file = StringIO(vtt_data)
261
+
262
+ def to_sec(t):
263
+ """Convert 'HH:MM:SS.mmm' to float seconds"""
264
+ h, m, s = t.split(":")
265
+ s, ms = s.split(".")
266
+ return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000
267
+
268
+ for caption in webvtt.read_buffer(vtt_file):
269
+ structured_captions.append(
270
+ {
271
+ "start": to_sec(caption.start),
272
+ "end": to_sec(caption.end),
273
+ "text": caption.text.strip(),
274
+ }
275
+ )
276
+ return structured_captions
277
+
278
+ def _select_video_format(self, formats, video_quality):
279
+ video_format = next(
280
+ f
281
+ for f in formats
282
+ if f.get("vcodec") != "none" and f.get("height") == video_quality
283
+ )
284
+ return video_format
285
+
286
+ def _capture_video_frames(self, video_url, capture_interval_sec=None):
287
+ ffmpeg_cmd = [
288
+ "ffmpeg",
289
+ "-i",
290
+ video_url,
291
+ "-f",
292
+ "matroska", # container format
293
+ "-",
294
+ ]
295
+
296
+ process = subprocess.Popen(
297
+ ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL
298
+ )
299
+
300
+ container = av.open(process.stdout)
301
+ stream = container.streams.video[0]
302
+ time_base = stream.time_base
303
+
304
+ frames = []
305
+ next_capture_time = 0
306
+ for frame in container.decode(stream):
307
+ if frame.pts is None:
308
+ continue
309
+
310
+ timestamp = float(frame.pts * time_base)
311
+ if capture_interval_sec is None or timestamp >= next_capture_time:
312
+ frames.append(
313
+ {
314
+ "timestamp": timestamp,
315
+ "image": frame.to_image(), # PIL image
316
+ }
317
+ )
318
+ if capture_interval_sec is not None:
319
+ next_capture_time += capture_interval_sec
320
+
321
+ process.terminate()
322
+ return frames
323
+
324
+ def _base64_frames(self, frames):
325
+ base64_frames = []
326
+ for f in frames:
327
+ buffered = BytesIO()
328
+ f["image"].save(buffered, format="JPEG")
329
+ encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
330
+ base64_frames.append(encoded)
331
+ return base64_frames
332
+
333
+ def _select_audio_format(self, formats):
334
+ audio_formats = [
335
+ f
336
+ for f in formats
337
+ if f.get("vcodec") == "none"
338
+ and f.get("acodec")
339
+ and f.get("acodec") != "none"
340
+ ]
341
+
342
+ if not audio_formats:
343
+ raise ValueError("No valid audio-only formats found.")
344
+
345
+ # Prefer m4a > webm, highest abr first
346
+ preferred_exts = ["m4a", "webm"]
347
+
348
+ def sort_key(f):
349
+ ext_score = (
350
+ preferred_exts.index(f["ext"]) if f["ext"] in preferred_exts else 99
351
+ )
352
+ abr = f.get("abr") or 0
353
+ return (ext_score, -abr)
354
+
355
+ audio_formats.sort(key=sort_key)
356
+ return audio_formats[0]["url"]
357
+
358
+ def _capture_audio(self, audio_url) -> BytesIO:
359
+ audio_buffer = BytesIO()
360
+ ffmpeg_audio_cmd = [
361
+ "ffmpeg",
362
+ "-i",
363
+ audio_url,
364
+ "-f",
365
+ "wav",
366
+ "-acodec",
367
+ "pcm_s16le", # Whisper prefers PCM
368
+ "-ac",
369
+ "1", # Mono
370
+ "-ar",
371
+ "16000", # 16kHz for Whisper
372
+ "-",
373
+ ]
374
+
375
+ result = subprocess.run(
376
+ ffmpeg_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
377
+ )
378
+ if result.returncode != 0:
379
+ raise RuntimeError("ffmpeg failed:\n" + result.stderr.decode())
380
+
381
+ audio_buffer = BytesIO(result.stdout)
382
+ audio_buffer.seek(0)
383
+ return audio_buffer
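
A usage sketch of the video tool wired to the speech tool, mirroring how agentsList.py constructs it (the URL and question are placeholders; ffmpeg, yt-dlp, and OPENAI_API_KEY are assumed to be available):

from openai import OpenAI

speech = SpeechRecognitionTool()
video = YoutubeVideoTool(
    client=OpenAI(),
    speech_recognition_tool=speech,
    frames_interval=3,
    chunk_duration=60,
)
print(video.forward(url="https://www.youtube.com/watch?v=XXXXXXXXXXX", query="What species of bird appears?"))
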