Francesco Laiti
Add .gitignore, implement AskMeAnythingAgent in agent.py, and update app.py to utilize the new agent. Include instruction prompts in YAML format.
a745ced
| import os | |
| from dotenv import load_dotenv | |
| import yaml | |
| import re | |
| import json | |
| import pandas as pd | |
| from pathlib import Path | |
| import time | |
| from agno.agent import Agent, RunResponse | |
| from agno.models.google import Gemini | |
| from agno.media import Image, Audio, Video, File | |
| from agno.tools.googlesearch import GoogleSearchTools | |
| from agno.tools.calculator import CalculatorTools | |
# Base URL for attachments: the part of a file name before its first "." is
# appended to this prefix to build the download URL (download_file appends the
# full file name instead).
URL_BASE_DOWNLOAD = "https://agents-course-unit4-scoring.hf.space/files/"
# Load variables from a local .env file into the process environment
# (presumably the API credentials used by the Gemini model -- confirm).
load_dotenv()
class AskMeAnythingAgent:
    """Question-answering agent built on an agno ``Agent`` with a Gemini model.

    The agent is configured from ``instruction_prompts.yaml`` (keys ``role``,
    ``instructions`` and ``goal``) and is given Google-search and calculator
    tools. Calling the instance with a question (and optionally the name of
    an attachment) runs the agent and returns the text that follows the
    ``[[ ## ANSWER ## ]]`` marker in its reply.
    """

    # Marker the model puts in front of its final answer (two spacings).
    _ANSWER_HEADER = re.compile(r"\[\[ ## ANSWER ## \]\]|\[\[## ANSWER ##\]\]")
    # Matches a YouTube "watch" URL embedded in a question.
    _YOUTUBE_WATCH_URL = re.compile(r'https://www\.youtube\.com/watch\?v=[\w-]+')

    def __init__(self, model_id: str = "gemini-2.0-flash", debug: bool = False):
        """Load the prompt configuration and build the underlying agent.

        Args:
            model_id: Identifier passed to the ``Gemini`` model wrapper.
            debug: Forwarded to the agent as ``debug_mode``.

        Raises:
            FileNotFoundError: If ``instruction_prompts.yaml`` is not present
                in the current working directory.
            KeyError: If the YAML file lacks ``role``, ``instructions`` or
                ``goal``.
        """
        # Explicit encoding so the prompts load identically on every platform.
        with open("instruction_prompts.yaml", "r", encoding="utf-8") as file:
            self.instruction_prompts = yaml.safe_load(file)
        self.main_agent = Agent(
            name="Ask Me Anything",
            model=Gemini(id=model_id),
            role=self.instruction_prompts["role"],
            instructions=self.instruction_prompts["instructions"],
            goal=self.instruction_prompts["goal"],
            tools=[GoogleSearchTools(), CalculatorTools()],
            debug_mode=debug,
        )

    def __call__(self, question: str, file_name: str = "", debug: bool = False) -> str:
        """Answer ``question``, attaching the media named by ``file_name``.

        Args:
            question: The question text sent to the agent.
            file_name: Name of an attachment; empty means no attachment.
            debug: Accepted for interface compatibility; currently unused.

        Returns:
            The parsed answer, or ``"invalid"`` when the reply carries no
            answer marker.
        """
        payload = self.manage_payload(file_name, question)
        response: RunResponse = self.main_agent.run(message=question, **payload)
        return self.parse_response(response.content)

    def manage_payload(self, file_name: str, question: str):
        """Build the media keyword arguments for the agent run.

        Args:
            file_name: Attachment name; its extension selects the media type
                and the part before its first ``.`` selects the remote file.
            question: Question text, scanned for a YouTube watch URL when no
                attachment is given.

        Returns:
            A dict with one of the keys ``images``, ``audio``, ``files`` or
            ``videos``, or an empty dict when there is nothing to attach.
        """
        if file_name:  # also treats None like "no attachment"
            base_name = file_name.split(".")[0]
            remote_url = URL_BASE_DOWNLOAD + base_name
            if file_name.endswith(".png"):
                return {"images": [Image(url=remote_url, format="png")]}
            if file_name.endswith(".mp3"):
                return {"audio": [Audio(url=remote_url, format="mp3")]}
            if file_name.endswith(".xlsx"):
                # The spreadsheet is converted to CSV and attached as text.
                csv_path = f"media/{base_name}.csv"
                # The target directory may not exist yet on a fresh checkout.
                os.makedirs("media", exist_ok=True)
                df = pd.read_excel(io=remote_url)
                df.to_csv(csv_path, index=False)
                return {"files": [File(filepath=csv_path, mime_type="text/csv")]}
            return {"files": [File(url=remote_url)]}

        if "https://www.youtube.com/" in question:
            found = self._YOUTUBE_WATCH_URL.search(question)
            if found is None:
                # A YouTube link that is not a watch URL (channel, shorts...):
                # nothing downloadable, so send the question without media.
                return {}
            path_to_file = self.download_yt_video(found.group(0))
            return {"videos": [Video(filepath=path_to_file)]}

        return {}

    def download_yt_video(self, url: str):
        """Download a YouTube video into the module's ``media`` directory.

        The download is skipped when the target file already exists.

        Args:
            url: A watch URL containing a ``v=<id>`` query parameter.

        Returns:
            ``Path`` of ``media/<id>.mp4`` next to this module.
        """
        import yt_dlp

        # Check, download and return against the same directory, so the
        # result does not depend on the current working directory.
        media_dir = Path(__file__).parent.joinpath("media")
        # NOTE(review): assumes the 'best' format yields an .mp4 file -- confirm.
        video_path = media_dir.joinpath(f"{url.split('v=')[1]}.mp4")
        ydl_opts = {
            'format': 'best',
            'paths': {'home': str(media_dir)},
            'outtmpl': '%(id)s.%(ext)s',
            'quiet': True,
        }
        os.makedirs(media_dir, exist_ok=True)
        if not video_path.exists():
            print(f"Downloading video from {url}")
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([url])
        return video_path

    def download_file(self, file_name: str):
        """Download ``file_name`` from the base URL into ``media/``.

        An already-downloaded file is reused without contacting the server.

        Args:
            file_name: Name appended to ``URL_BASE_DOWNLOAD`` and used as the
                local file name.

        Returns:
            The local path ``media/<file_name>``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        import requests

        local_path = f"media/{file_name}"
        if os.path.exists(local_path):
            return local_path

        response = requests.get(URL_BASE_DOWNLOAD + file_name, timeout=60)
        # Fail loudly rather than saving an error page as the file content.
        response.raise_for_status()
        os.makedirs("media", exist_ok=True)
        with open(local_path, 'wb') as file:
            file.write(response.content)
        return local_path

    def parse_response(self, response: str):
        """Extract the answer that follows the answer marker.

        Args:
            response: Raw text returned by the agent.

        Returns:
            The rest of the first line containing the marker, stripped, or
            ``"invalid"`` when there is no marker or ``response`` is not a
            string.
        """
        if not isinstance(response, str):
            return "invalid"
        for raw_line in response.splitlines():
            line = raw_line.strip()
            match = self._ANSWER_HEADER.search(line)
            if match:
                # Slice the same string that was searched so the match
                # offset stays valid for indented lines.
                return line[match.end():].strip()
        return "invalid"
if __name__ == "__main__":
    # Batch driver: answer every entry of questions.json, pausing after each
    # group of five questions to stay below the API's requests-per-minute cap.
    with open("questions.json", "r") as fh:
        entries = json.load(fh)

    for number, entry in enumerate(entries, start=1):
        # A fresh agent is built for every question.
        agent = AskMeAnythingAgent(debug=False)
        prompt = entry["question"]
        attachment = entry["file_name"]

        print(f"{number}. QUESTION: ", prompt, "\nFILE: ", attachment)
        answer = agent(prompt, attachment)
        print(f"{number}. ANSWER: ", answer, "\n\n")

        if number % 5 != 0:
            continue
        print("Sleeping for 30 seconds to avoid RPM limit from API")
        time.sleep(30)