|
from typing import Any, Dict, List, Optional, TypedDict, Annotated |
|
import operator |
|
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage |
|
from langchain_openai import ChatOpenAI |
|
from langchain_core.tools import tool |
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
from langgraph.graph import StateGraph, START, END |
|
import os |
|
import requests |
|
import json |
|
from dotenv import load_dotenv |
|
|
|
# Load environment variables via python-dotenv before the key lookup below.
load_dotenv()

# Base URL that the request helpers in this file prefix to their endpoint paths.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# None when the variable is not set in the environment.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Module-level chat model shared by the graph node(s); temperature is pinned to 0.
llm = ChatOpenAI(api_key=OPENAI_API_KEY, model="gpt-4o", temperature=0)
|
|
|
class AgentState(TypedDict):
    """State dictionary passed through the agent graph."""

    # Identifier of the task the question belongs to.
    task_id: str
    # Question text sent to the model as the human message.
    question: str
    # Model response text written by the assistant node.
    answer: str
    # Log entries; the operator.add metadata marks this key as one whose
    # updates are combined by list concatenation (LangGraph reducer).
    log: Annotated[List[str], operator.add]
|
|
|
def assistant(state: AgentState) -> Dict[str, Any]:
    """Ask the chat model the question held in ``state`` and record the reply.

    Args:
        state: Current graph state; only ``state["question"]`` is read.

    Returns:
        A partial state update (not a complete ``AgentState``) holding
        ``answer`` (the model's response content) and ``log`` (a
        single-entry list describing that response).
    """
    # Same prompt text as before, split across literals purely for readability.
    system_prompt = (
        "You are a general AI assistant. I will ask you a question. "
        "Report your thoughts, and finish your answer with the following "
        "template: FINAL ANSWER: <your answer here>. "
        "YOUR FINAL ANSWER should be a number OR as few words as possible "
        "OR a comma separated list of numbers and/or strings. "
        "If you are asked for a number, don't use comma to write your "
        "number neither use units such as $ or percent sign unless "
        "specified otherwise. "
        "If you are asked for a string, don't use articles, neither "
        "abbreviations (e.g. for cities), and write the digits in plain "
        "text unless specified otherwise. "
        "If you are asked for a comma separated list, apply the above "
        "rules depending of whether the element to be put in the list is "
        "a number or a string."
    )
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=state["question"]),
    ]
    response = llm.invoke(messages)
    reply = response.content
    # Only the keys being updated are returned; task_id and question are left alone.
    return {"answer": reply, "log": [f"Assistant response: {reply}"]}
|
|
|
|
|
def get_all_questions():
    """Return the decoded JSON body of ``GET {DEFAULT_API_URL}/questions``.

    A ``requests`` failure (including an error status raised by
    ``raise_for_status``) is printed and an empty list is returned instead.
    """
    url = f"{DEFAULT_API_URL}/questions"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        payload = resp.json()
    except requests.exceptions.RequestException as err:
        print(f"Error fetching questions: {err}")
        return []
    return payload
|
|
|
def get_random_question():
    """Return the decoded JSON body of ``GET {DEFAULT_API_URL}/random-question``.

    A ``requests`` failure (including an error status raised by
    ``raise_for_status``) is printed and ``None`` is returned instead.
    """
    url = f"{DEFAULT_API_URL}/random-question"
    try:
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        payload = resp.json()
    except requests.exceptions.RequestException as err:
        print(f"Error fetching random question: {err}")
        return None
    return payload
|
|
|
def get_file_for_task(task_id: str):
    """Return the raw response bytes of ``GET {DEFAULT_API_URL}/files/{task_id}``.

    A ``requests`` failure (including an error status raised by
    ``raise_for_status``) is printed and ``None`` is returned instead.
    """
    url = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
        raw_bytes = resp.content
    except requests.exceptions.RequestException as err:
        print(f"Error fetching file for task {task_id}: {err}")
        return None
    return raw_bytes
|
|
|
def submit_answers(username: str, agent_code: str, answers: List[Dict]):
    """POST a submission to ``{DEFAULT_API_URL}/submit`` and return the decoded JSON reply.

    The JSON body carries ``username``, ``agent_code`` and ``answers`` exactly
    as given. A ``requests`` failure (including an error status raised by
    ``raise_for_status``) is printed and ``None`` is returned instead.
    """
    url = f"{DEFAULT_API_URL}/submit"
    payload = dict(username=username, agent_code=agent_code, answers=answers)
    try:
        resp = requests.post(url, json=payload, timeout=60)
        resp.raise_for_status()
        reply = resp.json()
    except requests.exceptions.RequestException as err:
        print(f"Error submitting answers: {err}")
        return None
    return reply
|
|
|
|
|
# Single-node workflow over AgentState: START -> "assistant" -> END.
graph = StateGraph(AgentState)
graph.add_node("assistant", assistant)
graph.add_edge(START, "assistant")
graph.add_edge("assistant", END)

# Compiled form of the graph; run_agent() calls app.invoke(...) on it.
app = graph.compile()
|
|
|
def run_agent(question: str, task_id: str):
    """Invoke the compiled graph once for one question and return what it yields.

    The initial state carries the question, the task id and an empty log;
    no ``answer`` key is supplied up front.
    """
    initial_state = dict(question=question, task_id=task_id, log=[])
    final_state = app.invoke(initial_state)
    return final_state
|
|
|
def run_agent_on_all_questions() -> List[Dict[str, Any]]:
    """Run the agent on every question returned by ``get_all_questions``.

    Entries missing a ``task_id`` or ``question`` are skipped with a message.
    Progress and each answer are printed as the run proceeds.

    Returns:
        One dict per processed question with the keys ``task_id``,
        ``question``, ``answer`` and ``log``. When no questions could be
        fetched an empty list is returned (rather than ``None``), so the
        result is always a list, matching ``get_all_questions``.
    """
    print("Fetching all questions...")
    questions = get_all_questions()

    if not questions:
        print("No questions found or error occurred")
        return []

    total = len(questions)
    print(f"Found {total} questions")

    results: List[Dict[str, Any]] = []
    # 1-based position so the skip and progress messages number questions the same way.
    for position, question_data in enumerate(questions, start=1):
        task_id = question_data.get("task_id")
        question_text = question_data.get("question")

        if not task_id or not question_text:
            print(f"Skipping malformed question {position}/{total}")
            continue

        print(f"\nProcessing question {position}/{total}")
        print(f"Task ID: {task_id}")
        # Only the first 100 characters are echoed to keep the console readable.
        print(f"Question: {question_text[:100]}...")

        result = run_agent(question_text, task_id)
        results.append(
            {
                "task_id": task_id,
                "question": question_text,
                "answer": result["answer"],
                "log": result["log"],
            }
        )
        print(f"Answer: {result['answer']}")

    return results
|
|
|
def demo_single_question() -> Optional[Dict[str, Any]]:
    """Fetch one random question, run the agent on it and print the outcome.

    Returns:
        The value returned by ``run_agent`` for the question, or ``None``
        when the question could not be fetched or lacks a ``task_id`` or
        ``question`` field.
    """
    print("Fetching a random question...")
    question_data = get_random_question()

    if not question_data:
        print("Could not fetch random question")
        return None

    task_id = question_data.get("task_id")
    question_text = question_data.get("question")

    # Same validation the batch runner applies: do not send an empty or
    # missing question to the model.
    if not task_id or not question_text:
        print("Skipping malformed random question")
        return None

    print(f"Task ID: {task_id}")
    print(f"Question: {question_text}")

    result = run_agent(question_text, task_id)

    print(f"\nAnswer: {result['answer']}")
    print(f"Log: {result['log']}")

    return result
|
|
|
if __name__ == "__main__":
    # Batch run over every question the API returns.
    print("=== Running on All Questions ===")
    results = run_agent_on_all_questions()

    # Write the results file only when the batch produced something.
    if results:
        with open('agent_results.json', 'w') as out_file:
            json.dump(results, out_file, indent=2)
        print("\nResults saved to agent_results.json")

    # Fixed sanity-check question, independent of the API.
    print("=== Manual Test ===")
    manual_question = "What is the capital of France?"
    manual_task_id = "test-123"
    manual_result = run_agent(manual_question, manual_task_id)
    print(f"Question: {manual_question}")
    print(f"Answer: {manual_result['answer']}")