File size: 4,394 Bytes
a745ced
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from dotenv import load_dotenv
import yaml
import re
import json
import pandas as pd
from pathlib import Path
import time

from agno.agent import Agent, RunResponse
from agno.models.google import Gemini
from agno.media import Image, Audio, Video, File

from agno.tools.googlesearch import GoogleSearchTools
from agno.tools.calculator import CalculatorTools


# Base URL of the scoring service's file endpoint; callers append a file
# identifier to it to fetch a task attachment.
URL_BASE_DOWNLOAD = "https://agents-course-unit4-scoring.hf.space/files/"
# Load variables from a local .env file into os.environ at import time
# (presumably the Gemini / search API credentials — confirm against .env).
load_dotenv()

class AskMeAnythingAgent:
    """Question-answering agent backed by a Gemini model with search and calculator tools.

    The agent's role, instructions and goal are read from
    ``instruction_prompts.yaml`` in the current working directory. Calling an
    instance with a question (and optionally the name of an attached task
    file) runs the agent and returns the text that follows the
    ``[[ ## ANSWER ## ]]`` header in the model output, or ``"invalid"`` when
    no such header is present.
    """

    # Answer header expected in the model output. Any whitespace inside the
    # brackets and any letter case are accepted, so "[[ ## ANSWER ## ]]",
    # "[[## ANSWER ##]]" and "[[ ## answer ## ]]" all match.
    _ANSWER_HEADER = re.compile(r"\[\[\s*##\s*ANSWER\s*##\s*\]\]", re.IGNORECASE)

    def __init__(self, model_id: str = "gemini-2.0-flash", debug: bool = False):
        """Build the underlying agno ``Agent``.

        Args:
            model_id: Gemini model identifier passed to ``Gemini(id=...)``.
            debug: Forwarded to the agent as ``debug_mode``.

        Raises:
            FileNotFoundError: If ``instruction_prompts.yaml`` is missing.
            KeyError: If the YAML lacks a ``role``, ``instructions`` or ``goal`` key.
        """
        with open("instruction_prompts.yaml", "r", encoding="utf-8") as file:
            self.instruction_prompts = yaml.safe_load(file)
        self.main_agent = Agent(
            name="Ask Me Anything",
            model=Gemini(id=model_id),
            role=self.instruction_prompts["role"],
            instructions=self.instruction_prompts["instructions"],
            goal=self.instruction_prompts["goal"],
            tools=[GoogleSearchTools(), CalculatorTools()],
            debug_mode=debug,
        )

    def __call__(self, question: str, file_name: str = "", debug: bool = False) -> str:
        """Answer *question*, attaching the task file *file_name* when given.

        Args:
            question: The natural-language question.
            file_name: Attached task file as ``"<id>.<ext>"``; empty (or
                ``None``) when the question has no attachment.
            debug: Unused; kept for backward compatibility. Debug output is
                controlled by the ``debug`` flag passed to the constructor.

        Returns:
            The parsed answer, or ``"invalid"`` if no answer header was found.
        """
        payload = self.manage_payload(file_name, question)
        response: RunResponse = self.main_agent.run(message=question, **payload)
        return self.parse_response(response.content)

    def manage_payload(self, file_name: str, question: str) -> dict:
        """Build the media keyword arguments to pass to ``Agent.run``.

        The attachment URL is ``URL_BASE_DOWNLOAD`` plus the part of
        *file_name* before its first dot.

        - ``.png``: passed as an image URL.
        - ``.mp3``: passed as an audio URL.
        - ``.xlsx``: downloaded, converted to ``media/<id>.csv`` and attached
          as a CSV file.
        - Any other extension: attached as a file URL.
        - No attachment, but the question contains a
          ``https://www.youtube.com/watch?v=...`` link: the video is
          downloaded locally and attached.

        Returns:
            A dict with one of the keys ``images``, ``audio``, ``files`` or
            ``videos``, or ``{}`` when there is nothing to attach.
        """
        if file_name:
            task_id = file_name.split(".")[0]
            url = URL_BASE_DOWNLOAD + task_id
            lowered = file_name.lower()
            if lowered.endswith(".png"):
                return {"images": [Image(url=url, format="png")]}
            if lowered.endswith(".mp3"):
                return {"audio": [Audio(url=url, format="mp3")]}
            if lowered.endswith(".xlsx"):
                csv_path = f"media/{task_id}.csv"
                # DataFrame.to_csv does not create missing parent directories.
                os.makedirs("media", exist_ok=True)
                pd.read_excel(io=url).to_csv(csv_path, index=False)
                return {"files": [File(filepath=csv_path, mime_type="text/csv")]}
            return {"files": [File(url=url)]}
        if "https://www.youtube.com/" in question:
            match = re.search(r"https://www\.youtube\.com/watch\?v=[\w-]+", question)
            # YouTube links that are not "watch?v=" URLs (e.g. /shorts/ or a
            # channel page) cannot be downloaded here, so attach nothing.
            if match is None:
                return {}
            path_to_file = self.download_yt_video(match.group(0))
            return {"videos": [Video(filepath=path_to_file)]}
        return {}

    def download_yt_video(self, url: str) -> Path:
        """Download a YouTube video into ``media/`` unless it is already cached.

        Args:
            url: A ``https://www.youtube.com/watch?v=<id>`` URL.

        Returns:
            Absolute path of the local video file.
        """
        import yt_dlp

        video_id = url.split("v=")[1]
        media_dir = Path("media")
        media_dir.mkdir(exist_ok=True)
        target = media_dir / f"{video_id}.mp4"
        ydl_opts = {
            # Prefer an mp4 stream so the cache check on `target` finds it on
            # later calls; fall back to the best stream in any container.
            "format": "best[ext=mp4]/best",
            "outtmpl": "media/%(id)s.%(ext)s",
            "quiet": True,
        }
        print(f"Downloading video from {url}")
        if not target.exists():
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                info = ydl.extract_info(url, download=True)
                # The real extension depends on the container actually chosen.
                target = Path(ydl.prepare_filename(info))
        # yt-dlp writes relative to the working directory, so resolve from there
        # (not from this module's directory) to return the file's true path.
        return target.resolve()

    def download_file(self, file_name: str) -> str:
        """Download ``URL_BASE_DOWNLOAD + file_name`` into ``media/`` unless cached.

        Returns:
            The relative path ``media/<file_name>``.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        import requests

        local_path = f"media/{file_name}"
        # Check the cache first so a cached file needs no network round-trip.
        if os.path.exists(local_path):
            return local_path
        response = requests.get(URL_BASE_DOWNLOAD + file_name, timeout=60)
        # Never cache an error page as if it were the requested file.
        response.raise_for_status()
        os.makedirs("media", exist_ok=True)
        with open(local_path, "wb") as file:
            file.write(response.content)
        return local_path

    def parse_response(self, response: str) -> str:
        """Extract the answer that follows the ``[[ ## ANSWER ## ]]`` header.

        The first line containing the header wins. The answer is the
        whitespace-stripped text after the header on that line.

        Args:
            response: Raw model output; ``None`` or ``""`` are tolerated.

        Returns:
            The answer text, or ``"invalid"`` if *response* is empty/``None``
            or contains no answer header.
        """
        if not response:
            return "invalid"
        for line in response.splitlines():
            stripped = line.strip()
            match = self._ANSWER_HEADER.search(stripped)
            if match is not None:
                # Slice the same string that was searched so match.end() lines up.
                return stripped[match.end():].strip()
        return "invalid"
            

if __name__ == "__main__":
    # Run every question from questions.json through the agent and print the
    # results, pausing periodically to stay under the API's request-rate limit.
    with open("questions.json", "r") as questions_file:
        all_questions = json.load(questions_file)

    for number, entry in enumerate(all_questions, start=1):
        # A fresh agent (and a fresh read of the prompts file) for each question.
        solver = AskMeAnythingAgent(debug=False)
        print(f"{number}. QUESTION: ", entry["question"], "\nFILE: ", entry["file_name"])
        print(f"{number}. ANSWER: ", solver(entry["question"], entry["file_name"]), "\n\n")
        if number % 5 == 0:
            print("Sleeping for 30 seconds to avoid RPM limit from API")
            time.sleep(30)