wt002 committed on
Commit
b403954
·
verified ·
1 Parent(s): b52644e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +328 -167
app.py CHANGED
@@ -1,209 +1,370 @@
1
  import os
2
- import gradio as gr
3
  import requests
4
  import pandas as pd
5
- from smolagents import CodeAgent, OpenAIServerModel, DuckDuckGoSearchTool, VisitWebpageTool, tool, \
6
- FinalAnswerTool, PythonInterpreterTool, SpeechToTextTool, ToolCallingAgent
7
- import yaml
8
- import importlib
9
- from io import BytesIO
10
- import tempfile
11
  import base64
12
- from youtube_transcript_api import YouTubeTranscriptApi
13
- from youtube_transcript_api._errors import TranscriptsDisabled, NoTranscriptFound, VideoUnavailable
14
- from urllib.parse import urlparse, parse_qs
15
- import json
16
- import whisper
17
- import re
18
-
19
 
 
 
20
 
21
- # (Keep Constants as is)
22
  # --- Constants ---
23
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
24
 
 
 
 
 
 
 
 
 
 
25
 
26
@tool
def transcribe_audio_file(file_path: str) -> str:
    """
    Transcribe a local MP3 audio file with Whisper.

    Args:
        file_path: Full path to the .mp3 audio file.

    Returns:
        A JSON-formatted string. On success:
        {"success": true, "transcript": [{"start": 0.0, "end": 5.2, "text": "..."}, ...]}
        On failure:
        {"success": false, "error": "<reason why transcription failed>"}
    """
    def _failure(reason: str) -> str:
        # Uniform JSON error envelope shared by all failure paths.
        return json.dumps({"success": False, "error": reason})

    try:
        if not os.path.exists(file_path):
            return _failure("File does not exist.")
        if not file_path.lower().endswith(".mp3"):
            return _failure("Invalid file type. Only MP3 files are supported.")

        # 'base' trades accuracy for speed; 'tiny'/'small'/'medium'/'large' also work.
        asr_model = whisper.load_model("base")
        outcome = asr_model.transcribe(file_path, verbose=False, word_timestamps=False)

        segments = []
        for seg in outcome["segments"]:
            segments.append({
                "start": seg["start"],
                "end": seg["end"],
                "text": seg["text"].strip(),
            })

        return json.dumps({"success": True, "transcript": segments})
    except Exception as e:
        return _failure(str(e))
72
@tool
def get_youtube_transcript(video_url: str) -> str:
    """
    Retrieve the English transcript of a YouTube video, including timestamps.

    Automatically generated subtitles are supported. Each snippet carries its
    start time, duration, and text.

    Args:
        video_url: The full URL of the YouTube video (e.g., https://www.youtube.com/watch?v=12345)

    Returns:
        A JSON-formatted string. On success:
        {"success": true, "transcript": [{"start": 0.0, "duration": 1.54, "text": "..."}, ...]}
        On failure:
        {"success": false, "error": "<reason why the transcript could not be retrieved>"}
    """
    try:
        # A watch URL carries the video id in its "v" query parameter.
        query_params = parse_qs(urlparse(video_url).query)
        video_id = query_params.get("v", [None])[0]
        if not video_id:
            return json.dumps({"success": False, "error": "Invalid YouTube URL. Could not extract video ID."})

        snippets = YouTubeTranscriptApi().fetch(video_id)
        payload = []
        for snippet in snippets:
            payload.append({
                "start": snippet.start,
                "duration": snippet.duration,
                "text": snippet.text,
            })

        return json.dumps({"success": True, "transcript": payload})

    except VideoUnavailable:
        return json.dumps({"success": False, "error": "The video is unavailable."})
    except TranscriptsDisabled:
        return json.dumps({"success": False, "error": "Transcripts are disabled for this video."})
    except NoTranscriptFound:
        return json.dumps({"success": False, "error": "No transcript found for this video."})
    except Exception as e:
        return json.dumps({"success": False, "error": str(e)})
 
126
# --- Basic Agent Definition ---
# ----- THIS IS WHERE YOU CAN BUILD WHAT YOU WANT ------
class BasicAgent:
    """Wrapper around a smolagents CodeAgent with search, web, and audio tools."""

    def __init__(self):
        model = OpenAIServerModel(api_key=os.environ.get("OPENAI_API_KEY"), model_id="gpt-4o")

        self.code_agent = CodeAgent(
            tools=[PythonInterpreterTool(), DuckDuckGoSearchTool(), VisitWebpageTool(), transcribe_audio_file,
                   get_youtube_transcript,
                   FinalAnswerTool()],
            model=model,
            max_steps=20,
            name="hf_agent_course_final_assignment_solver",
            prompt_templates=yaml.safe_load(
                importlib.resources.files("prompts").joinpath("code_agent.yaml").read_text()
            )
        )
        print("BasicAgent initialized.")

    def __call__(self, task_id: str, question: str, file_name: str) -> str:
        """Run the agent on a question, optionally enriched with an attached file."""
        if file_name:
            question = self.enrich_question_with_associated_file_details(task_id, question, file_name)

        final_result = self.code_agent.run(question)

        # Extract text after "FINAL ANSWER:" (case-insensitive, trims whitespace).
        match = re.search(r'final answer:\s*(.*)', str(final_result), re.IGNORECASE | re.DOTALL)
        if match:
            return match.group(1).strip()

        # Fallback in case the pattern is not found.
        return str(final_result).strip()

    def enrich_question_with_associated_file_details(self, task_id: str, question: str, file_name: str) -> str:
        """Download the task's attached file and inline a usable representation.

        BUGFIX: this method was previously defined twice (identical copies,
        the second silently shadowing the first) and returned None for any
        unsupported extension, destroying the question. The duplicate is
        removed and unknown file types now fall through to the unmodified
        question.
        """
        api_url = DEFAULT_API_URL
        get_associated_files_url = f"{api_url}/files/{task_id}"
        response = requests.get(get_associated_files_url, timeout=15)
        response.raise_for_status()

        if file_name.endswith(".mp3"):
            # Persist to disk so the transcription tool can read it by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
                tmp_file.write(response.content)
                file_path = tmp_file.name
            return question + "\n\nMentioned .mp3 file local path is: " + file_path
        elif file_name.endswith(".py"):
            file_content = response.text
            return question + "\n\nBelow is mentioned Python file:\n\n```python\n" + file_content + "\n```\n"
        elif file_name.endswith(".xlsx"):
            xlsx_io = BytesIO(response.content)
            df = pd.read_excel(xlsx_io)
            file_content = df.to_csv(index=False)
            return question + "\n\nBelow is mentioned excel file in CSV format:\n\n```csv\n" + file_content + "\n```\n"
        elif file_name.endswith(".png"):
            base64_str = base64.b64encode(response.content).decode('utf-8')
            return question + "\n\nBelow is the .png image in base64 format:\n\n```base64\n" + base64_str + "\n```\n"

        # Unsupported extension: keep the original question instead of returning None.
        return question
209
  def run_and_submit_all( profile: gr.OAuthProfile | None):
 
1
  import os
2
+ from dotenv import load_dotenv
3
  import requests
4
  import pandas as pd
 
 
 
 
 
 
5
  import base64
6
+ import mimetypes
7
+ import tempfile
8
+ from smolagents import CodeAgent, OpenAIServerModel, tool
9
+ from dotenv import load_dotenv
10
+ from openai import OpenAI
 
 
11
 
12
# Load environment variables from a local .env file (no-op if the file is absent).
load_dotenv()

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# Initialize the OpenAI model.
# BUGFIX(consistency): the key was previously read only from the non-standard
# "openai" variable while the client below used "OPENAI_API_KEY"; accept either
# (old name first, for backward compatibility) so the two stay in sync.
model = OpenAIServerModel(
    model_id="o4-mini-2025-04-16",
    api_base="https://api.openai.com/v1",
    api_key=os.getenv("openai") or os.getenv("OPENAI_API_KEY"),
)

# Initialize the OpenAI client (used by the audio-transcription tool).
openAiClient = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
28
@tool
def tavily_search(query: str) -> str:
    """
    Perform a web search using the Tavily API.

    Args:
        query: The search query string

    Returns:
        A string containing the search results: Tavily's synthesized answer
        (when available) followed by a numbered source list, or an error message.
    """
    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        return "Error: TAVILY_API_KEY environment variable is not set"

    api_url = "https://api.tavily.com/search"

    headers = {
        "Content-Type": "application/json",
    }

    payload = {
        "api_key": api_key,
        "query": query,
        "search_depth": "advanced",
        "include_answer": True,       # ask Tavily to synthesize a direct answer
        "include_raw_content": False,
        "max_results": 5
    }

    try:
        # BUGFIX: a missing timeout could hang the agent indefinitely on a
        # stalled connection.
        response = requests.post(api_url, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        data = response.json()

        # Extract the answer and results.
        result = []
        if "answer" in data:
            result.append(f"Answer: {data['answer']}")

        if "results" in data:
            result.append("\nSources:")
            for i, item in enumerate(data["results"], 1):
                result.append(f"{i}. {item.get('title', 'No title')}: {item.get('url', 'No URL')}")

        return "\n".join(result)
    except Exception as e:
        return f"Error performing Tavily search: {str(e)}"
77
@tool
def analyze_image(image_url: str) -> str:
    """
    Analyze an image using OpenAI's vision model and return a description.

    Args:
        image_url: URL of the image to analyze

    Returns:
        A detailed description of the image, or an error message
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: OpenAI API key not set in environment variables"

    # Download the image and inline it as base64 (avoids requiring a public URL).
    try:
        # BUGFIX: added timeouts so a stalled download/API call cannot hang the agent.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        image_data = response.content
        base64_image = base64.b64encode(image_data).decode('utf-8')
    except Exception as e:
        return f"Error downloading image: {str(e)}"

    # Call the OpenAI chat-completions API with an image part.
    api_url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    payload = {
        "model": "gpt-4.1-2025-04-14",
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in detail. Include any text, objects, people, actions, and overall context."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            # NOTE(review): MIME is hard-coded to jpeg regardless of
                            # the actual image type; the API tolerates this in practice.
                            "url": f"data:image/jpeg;base64,{base64_image}"
                        }
                    }
                ]
            }
        ],
        "max_tokens": 500
    }

    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        data = response.json()

        if "choices" in data and len(data["choices"]) > 0:
            return data["choices"][0]["message"]["content"]
        else:
            return "No description generated"
    except Exception as e:
        return f"Error analyzing image: {str(e)}"
142
@tool
def analyze_sound(audio_url: str) -> str:
    """
    Transcribe an audio file using OpenAI's transcription endpoint.

    Args:
        audio_url: the url of the audio

    Returns:
        A transcription of the audio content, or an error message
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: OpenAI API key not set in environment variables"

    # Download the audio to a temp path the OpenAI SDK can stream from.
    try:
        # BUGFIX: added a timeout so a stalled download cannot hang the agent.
        response = requests.get(audio_url, timeout=30)
        response.raise_for_status()

        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as temp_file:
            temp_file.write(response.content)
            temp_file_path = temp_file.name
    except Exception as e:
        return f"Error downloading audio: {str(e)}"

    try:
        # BUGFIX: the file handle was previously opened and never closed, and
        # the temp file was never deleted — both leaked on every call.
        with open(temp_file_path, "rb") as audio_file:
            transcription = openAiClient.audio.transcriptions.create(
                model="gpt-4o-transcribe",
                file=audio_file
            )
        return transcription.text
    except Exception as e:
        return f"Error transcribing audio: {str(e)}"
    finally:
        try:
            os.unlink(temp_file_path)
        except OSError:
            pass  # best-effort cleanup; never mask the real result
181
@tool
def analyze_excel(excel_url: str) -> str:
    """
    Process an Excel file and convert it to a text-based format.

    Args:
        excel_url: URL of the Excel file to analyze

    Returns:
        A text representation of the Excel data (dimensions, column names,
        numeric summary, and the first five rows), or an error message
    """
    temp_file_path = None
    try:
        # Download the Excel file.
        # BUGFIX: added a timeout so a stalled download cannot hang the agent.
        response = requests.get(excel_url, timeout=30)
        response.raise_for_status()

        # Save to a temporary file so pandas can read it by path.
        with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as temp_file:
            temp_file.write(response.content)
            temp_file_path = temp_file.name

        # Read the Excel file.
        df = pd.read_excel(temp_file_path)

        # Convert to a text representation.
        result = []

        # Sheet dimensions.
        result.append(f"Excel file with {len(df)} rows and {len(df.columns)} columns")

        # Column names.
        result.append("\nColumns:")
        for i, col in enumerate(df.columns, 1):
            result.append(f"{i}. {col}")

        # Numeric summary.
        result.append("\nData Summary:")
        result.append(df.describe().to_string())

        # First few rows as a sample.
        result.append("\nFirst 5 rows:")
        result.append(df.head().to_string())

        return "\n".join(result)
    except Exception as e:
        return f"Error processing Excel file: {str(e)}"
    finally:
        # BUGFIX: the temp file was previously removed only on the success
        # path, leaking a file on every exception.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
231
@tool
def analyze_text(text_url: str) -> str:
    """
    Process a text file and return its contents.

    Args:
        text_url: URL of the text file to analyze

    Returns:
        The contents of the text file (truncated past 10,000 characters),
        or an error message
    """
    try:
        # Download the text file.
        # BUGFIX: added a timeout so a stalled download cannot hang the agent.
        response = requests.get(text_url, timeout=30)
        response.raise_for_status()

        # Get the text content.
        text_content = response.text

        # For very long files, truncate with a note to keep the agent context small.
        if len(text_content) > 10000:
            return f"Text file content (truncated to first 10000 characters):\n\n{text_content[:10000]}\n\n... [content truncated]"

        return f"Text file content:\n\n{text_content}"
    except Exception as e:
        return f"Error processing text file: {str(e)}"
258
@tool
def transcribe_youtube(youtube_url: str) -> str:
    """
    Extract the transcript from a YouTube video.

    Args:
        youtube_url: URL of the YouTube video

    Returns:
        The transcript of the video, or an error message
    """
    try:
        # Extract the video ID: an 11-character id follows "v=" or a path "/".
        import re
        video_id_match = re.search(r'(?:v=|\/)([0-9A-Za-z_-]{11}).*', youtube_url)
        if not video_id_match:
            return "Error: Invalid YouTube URL"

        video_id = video_id_match.group(1)

        # Use youtube_transcript_api to get the transcript.
        from youtube_transcript_api import YouTubeTranscriptApi

        try:
            transcript_list = YouTubeTranscriptApi.get_transcript(video_id)

            # BUGFIX(idiom): replaced a quadratic `+=` concatenation loop with
            # an equivalent join (same output: segments separated by single
            # spaces, trailing whitespace stripped).
            full_transcript = " ".join(segment['text'] for segment in transcript_list).strip()

            return f"YouTube Video Transcript:\n\n{full_transcript}"
        except Exception as e:
            return f"Error extracting transcript: {str(e)}"
    except Exception as e:
        return f"Error processing YouTube video: {str(e)}"
295
@tool
def process_file(task_id: str, file_name: str) -> str:
    """
    Fetch and process a file based on task_id and file_name.
    For images, it will analyze them and return a description of the image.
    For audio files, it will transcribe them.
    For Excel files, it will convert them to a text format.
    For text files, it will return the file contents.
    Other file types can be ignored for this tool.

    Args:
        task_id: The task ID to fetch the file for
        file_name: The name of the file to process

    Returns:
        A description or transcription of the file content
    """
    if not task_id or not file_name:
        return "Error: task_id and file_name are required"

    # The scoring service serves a task's attachment at /files/<task_id>.
    file_url = f"{DEFAULT_API_URL}/files/{task_id}"

    try:
        # Probe the URL first so failures surface here; the specialized tools
        # below re-download the file themselves.
        # BUGFIX: added a timeout so a stalled download cannot hang the agent.
        response = requests.get(file_url, timeout=30)
        response.raise_for_status()

        # The MIME type is guessed from the file NAME — the URL itself carries
        # no extension.
        mime_type, _ = mimetypes.guess_type(file_name)

        # Dispatch to the specialized tool for this file type.
        if mime_type and mime_type.startswith('image/'):
            return analyze_image(file_url)
        elif file_name.lower().endswith('.mp3') or (mime_type and mime_type.startswith('audio/')):
            return analyze_sound(file_url)
        elif file_name.lower().endswith('.xlsx') or (mime_type and mime_type == 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'):
            return analyze_excel(file_url)
        elif file_name.lower().endswith(('.txt', '.py', '.js', '.html', '.css', '.json', '.md')) or (mime_type and mime_type.startswith('text/')):
            return analyze_text(file_url)
        else:
            # Unknown type: report the successful fetch and guessed type only.
            return f"File '{file_name}' of type '{mime_type or 'unknown'}' was fetched successfully. Content processing not implemented for this file type."
    except Exception as e:
        return f"Error processing file: {str(e)}"
 
 
 
345
 
346
class BasicAgent:
    """
    A simple agent built on smolagents.CodeAgent with specialized tools:

    - tavily_search: web searches via the Tavily API
    - analyze_image: image description
    - analyze_sound: audio transcription
    - analyze_excel: spreadsheet summarization
    - analyze_text: code/text file contents
    - transcribe_youtube: video transcripts
    - process_file: dispatch over attached task files

    A single CodeAgent instance is created in __init__ and reused for every
    question to reduce per-call overhead.
    """

    def __init__(self):
        print("BasicAgent initialized.")
        # Reuse one CodeAgent instance for all queries.
        toolbox = [
            tavily_search,
            analyze_image,
            analyze_sound,
            analyze_excel,
            analyze_text,
            transcribe_youtube,
            process_file,
        ]
        self.agent = CodeAgent(tools=toolbox, model=model)

    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        return self.agent.run(question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
368
 
369
 
370
  def run_and_submit_all( profile: gr.OAuthProfile | None):