Francesco Laiti
Add .gitignore, implement AskMeAnythingAgent in agent.py, and update app.py to utilize the new agent. Include instruction prompts in YAML format.
a745ced
import os | |
from dotenv import load_dotenv | |
import yaml | |
import re | |
import json | |
import pandas as pd | |
from pathlib import Path | |
import time | |
from agno.agent import Agent, RunResponse | |
from agno.models.google import Gemini | |
from agno.media import Image, Audio, Video, File | |
from agno.tools.googlesearch import GoogleSearchTools | |
from agno.tools.calculator import CalculatorTools | |
# Base URL of the scoring server's file endpoint. A task attachment is fetched
# by appending its id (the file name without its extension) to this URL.
URL_BASE_DOWNLOAD = "https://agents-course-unit4-scoring.hf.space/files/"
# Load variables from a local .env file into the process environment
# (presumably the Gemini API credentials used by the agent — confirm).
load_dotenv()
class AskMeAnythingAgent:
    """Question-answering agent backed by a Gemini model via agno.

    The wrapped ``Agent`` has Google search and calculator tools. It takes its
    role, instructions and goal from ``instruction_prompts.yaml``. Calling an
    instance with a question (plus, optionally, the name of the task's
    attached file) runs the agent and extracts the final answer from the reply.
    """

    # Marker that precedes the final answer in the model's reply; both the
    # spaced and the compact spelling are accepted.
    _ANSWER_HEADER = re.compile(r"\[\[ ## ANSWER ## \]\]|\[\[## ANSWER ##\]\]")
    # A YouTube watch URL; only these carry a video id that can be downloaded.
    _YOUTUBE_WATCH_URL = re.compile(r"https://www\.youtube\.com/watch\?v=[\w-]+")
    # Seconds to wait on the file server so a stalled download cannot hang a run.
    _HTTP_TIMEOUT = 60

    def __init__(self, model_id: str = "gemini-2.0-flash", debug: bool = False):
        """Build the underlying agno agent.

        Args:
            model_id: Gemini model identifier passed to ``Gemini(id=...)``.
            debug: Forwarded to the agent's ``debug_mode``.

        Raises:
            FileNotFoundError: if ``instruction_prompts.yaml`` is not found in
                the current working directory.
            KeyError: if the loaded YAML mapping lacks a ``role``,
                ``instructions`` or ``goal`` entry.
        """
        # Resolved against the current working directory, not this module's folder.
        with open("instruction_prompts.yaml", "r", encoding="utf-8") as file:
            self.instruction_prompts = yaml.safe_load(file)

        self.main_agent = Agent(
            name="Ask Me Anything",
            model=Gemini(id=model_id),
            role=self.instruction_prompts["role"],
            instructions=self.instruction_prompts["instructions"],
            goal=self.instruction_prompts["goal"],
            tools=[GoogleSearchTools(), CalculatorTools()],
            debug_mode=debug,
        )

    def __call__(self, question: str, file_name: str = "", debug: bool = False) -> str:
        """Answer ``question`` and return the extracted final answer.

        Args:
            question: The question text. When no file is given and the text
                contains a YouTube watch URL, that video is downloaded and
                attached.
            file_name: Name of the task's attachment on the scoring server, or
                ``""`` when there is none.
            debug: Currently unused; kept for call-site compatibility. Use the
                constructor's ``debug`` to enable agent debug output.

        Returns:
            The text after the answer marker, or ``"invalid"`` when the reply
            contains no marker.
        """
        payload = self.manage_payload(file_name, question)
        response: RunResponse = self.main_agent.run(message=question, **payload)
        return self.parse_response(response.content)

    @staticmethod
    def _task_id(file_name: str) -> str:
        """Return the part of ``file_name`` before its first dot (used as the download key)."""
        return file_name.split(".")[0]

    def manage_payload(self, file_name: str, question: str):
        """Build the media keyword arguments to pass to ``Agent.run``.

        Returns:
            A dict with one of ``images``, ``audio``, ``files`` or ``videos``
            (each a one-element list), or ``{}`` when there is nothing to attach.
        """
        if file_name:
            return self._file_payload(file_name)
        if "https://www.youtube.com/" in question:
            match = self._YOUTUBE_WATCH_URL.search(question)
            # Non-watch YouTube links (shorts, channels, ...) have no video id to
            # download; answer from the text alone instead of crashing.
            if match is None:
                return {}
            path_to_file = self.download_yt_video(match.group(0))
            return {"videos": [Video(filepath=path_to_file)]}
        return {}

    def _file_payload(self, file_name: str) -> dict:
        """Return the ``Agent.run`` media kwargs for a task attachment, chosen by extension."""
        task_id = self._task_id(file_name)
        url = URL_BASE_DOWNLOAD + task_id
        if file_name.endswith(".png"):
            return {"images": [Image(url=url, format="png")]}
        if file_name.endswith(".mp3"):
            return {"audio": [Audio(url=url, format="mp3")]}
        if file_name.endswith(".xlsx"):
            # Convert the spreadsheet to CSV on disk and attach that instead
            # of the raw .xlsx.
            df = pd.read_excel(io=url)
            os.makedirs("media", exist_ok=True)  # to_csv does not create parent dirs
            csv_path = f"media/{task_id}.csv"
            df.to_csv(csv_path, index=False)
            return {"files": [File(filepath=csv_path, mime_type="text/csv")]}
        return {"files": [File(url=url)]}

    def download_yt_video(self, url: str) -> Path:
        """Download a YouTube video into ``media/`` unless already cached.

        Args:
            url: A ``https://www.youtube.com/watch?v=<id>`` URL.

        Returns:
            Absolute path of ``media/<id>.mp4`` under the current working
            directory, which is where the download is written.
        """
        import yt_dlp

        video_id = url.split("v=")[1]
        target = Path("media") / f"{video_id}.mp4"
        os.makedirs("media", exist_ok=True)
        print(f"Downloading video from {url}")
        if not target.exists():
            ydl_opts = {
                # Prefer an mp4 stream so the file lands at the .mp4 path
                # returned below; otherwise fall back to the best single file.
                "format": "best[ext=mp4]/best",
                "outtmpl": "media/%(id)s.%(ext)s",
                "quiet": True,
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([url])
        # outtmpl is CWD-relative, so resolve against CWD (not this module's dir).
        return target.absolute()

    def download_file(self, file_name: str):
        """Download a task attachment to ``media/<file_name>`` unless already cached.

        Returns:
            The relative path ``media/<file_name>``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        import requests

        os.makedirs("media", exist_ok=True)
        local_path = f"media/{file_name}"
        if not os.path.exists(local_path):
            # The server keys files by id (no extension), as in manage_payload.
            response = requests.get(
                URL_BASE_DOWNLOAD + self._task_id(file_name),
                timeout=self._HTTP_TIMEOUT,
            )
            # Never cache an error page as if it were the attachment.
            response.raise_for_status()
            with open(local_path, "wb") as file:
                file.write(response.content)
        return local_path

    def parse_response(self, response: str):
        """Extract the final answer from the model's reply.

        Scans ``response`` line by line and returns the remainder of the first
        line containing the answer marker, stripped of surrounding whitespace.

        Returns:
            The answer text, or ``"invalid"`` when ``response`` is not a
            string, is empty, or contains no marker.
        """
        if not isinstance(response, str):
            return "invalid"
        for line in response.splitlines():
            stripped = line.strip()
            match = self._ANSWER_HEADER.search(stripped)
            if match:
                # Slice the same string that was searched so the offsets line up.
                return stripped[match.end():].strip()
        return "invalid"
if __name__ == "__main__":
    # Local smoke run: answer every question in questions.json, pausing after
    # each batch so requests stay under the API's requests-per-minute limit.
    BATCH_SIZE = 5
    PAUSE_SECONDS = 30

    with open("questions.json", "r", encoding="utf-8") as file:
        questions = json.load(file)

    for position, question in enumerate(questions, start=1):
        # A new agent per question, so no agent instance is shared between runs.
        agent = AskMeAnythingAgent(debug=False)
        print(f"{position}. QUESTION: ", question["question"], "\nFILE: ", question["file_name"])
        print(f"{position}. ANSWER: ", agent(question["question"], question["file_name"]), "\n\n")
        # Pause between batches, but not after the final question.
        if position % BATCH_SIZE == 0 and position < len(questions):
            print(f"Sleeping for {PAUSE_SECONDS} seconds to avoid RPM limit from API")
            time.sleep(PAUSE_SECONDS)