kit086 commited on
Commit
ba5efb5
·
1 Parent(s): 7d7185d

feat: 优化

Browse files
Files changed (8) hide show
  1. .gitignore +1 -0
  2. agent.py +227 -91
  3. app.py +5 -4
  4. metadata.jsonl +0 -0
  5. pyproject.toml +3 -0
  6. supabase_docs.csv +0 -0
  7. system_prompt.txt +18 -0
  8. uv.lock +0 -0
.gitignore CHANGED
@@ -1,3 +1,4 @@
1
  .venv
2
  .env
3
  **/__pycache__
 
 
1
  .venv
2
  .env
3
  **/__pycache__
4
+ chroma_db
agent.py CHANGED
@@ -1,97 +1,233 @@
1
- """agent.py
2
- High-level Agent wrapper used by `app.py`.
3
-
4
- Current implementation leverages:
5
- • Gemini 1.5 Flash via `ChatGoogleGenerativeAI` (requires `GOOGLE_API_KEY` env).
6
- • The LangChain `zero-shot-react-description` agent (simple & robust).
7
- • Tools defined in `tools.create_tools`.
8
-
9
- Later we can migrate the control loop to LangGraph, but this version already
10
- provides a working agent that meets the API expectations (callable returning a
11
- plain string answer).
12
- """
13
- from __future__ import annotations
14
-
15
  import os
16
- import re
17
- from typing import List
18
-
19
  from dotenv import load_dotenv
20
- from langchain.agents import AgentType, initialize_agent
 
 
21
  from langchain_google_genai import ChatGoogleGenerativeAI
22
- from langchain_core.callbacks import BaseCallbackHandler
23
-
24
- from tools import create_tools # noqa: E402
25
-
26
- # -----------------------------------------------------------------------------
27
- # Callback for minimal logging (optional)
28
- # -----------------------------------------------------------------------------
29
- class PrintCallback(BaseCallbackHandler):
30
- """Simple callback that prints agent thoughts for debugging."""
31
-
32
- def on_llm_new_token(self, token: str, **kwargs): # noqa: D401
33
- print(token, end="", flush=True)
34
-
35
-
36
- # -----------------------------------------------------------------------------
37
- # Helper to strip template markers from final answer
38
- # -----------------------------------------------------------------------------
39
- _SYSTEM_PROMPT = (
40
- "You are a general AI assistant. I will ask you a question. "
41
- "Report your thoughts, and finish your answer with the following template: "
42
- "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR "
43
- "as few words as possible OR a comma separated list of numbers and/or "
44
- "strings. If you are asked for a number, don't use comma to write your "
45
- "number neither use units such as $ or percent sign unless specified "
46
- "otherwise. If you are asked for a string, don't use articles, neither "
47
- "abbreviations (e.g. for cities), and write the digits in plain text unless "
48
- "specified otherwise. If you are asked for a comma separated list, apply "
49
- "the above rules depending of whether the element to be put in the list is "
50
- "a number or a string."
51
- )
52
- _FINAL_PATTERN = re.compile(r"FINAL ANSWER:\s*(.*)", re.IGNORECASE | re.DOTALL)
53
-
54
-
55
- class Agent:
56
- """High-level callable Agent used by `app.py`."""
57
-
58
- def __init__(self, *, temperature: float = 0.0):
59
- # Ensure env vars are loaded
60
- load_dotenv()
61
-
62
- api_key = os.getenv("GOOGLE_API_KEY")
63
- if not api_key:
64
- raise EnvironmentError("GOOGLE_API_KEY not found in environment or .env file.")
65
-
66
- # Initialise LLM
67
- self.llm = ChatGoogleGenerativeAI(
68
- model="gemini-2.5-flash",
69
- temperature=temperature,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  )
71
-
72
- # Aggregate tools
73
- self.tools = create_tools()
74
-
75
- # Build agent executor (Zero-Shot ReAct)
76
- self.agent_executor = initialize_agent(
77
- self.tools,
78
- self.llm,
79
- agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
80
- verbose=False,
81
- handle_parsing_errors=True,
82
- system_message=_SYSTEM_PROMPT,
83
- callbacks=[PrintCallback()],
84
  )
 
 
 
 
 
 
 
 
 
 
85
 
86
- # ------------------------------------------------------------------
87
- # Public API
88
- # ------------------------------------------------------------------
89
- def __call__(self, question: str) -> str: # noqa: D401
90
- """Return the agent's answer as a plain string (no prefix)."""
91
- print(f"Agent received question: {question[:80]}…")
92
- raw_answer: str = self.agent_executor.run(question)
93
- # Post-process to remove leading template if present
94
- match = _FINAL_PATTERN.search(raw_answer)
95
- answer = match.group(1).strip() if match else raw_answer.strip()
96
- print(f"Agent final answer: {answer}")
97
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
 
 
 
2
  from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
  from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from langchain_community.tools.tavily_search import TavilySearchResults
9
+ from langchain_community.document_loaders import WikipediaLoader
10
+ from langchain_community.document_loaders import ArxivLoader
11
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
12
+ from langchain_core.tools import tool
13
+ from langchain.tools.retriever import create_retriever_tool
14
+ from langchain_community.vectorstores import Chroma
15
+ from langchain_core.documents import Document
16
+ import shutil
17
+ import pandas as pd  # used to load the question/answer CSV
18
+ import json  # used to parse the metadata column
19
+
20
+ load_dotenv()
21
+
22
+ # Tools:
23
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
31
+
32
@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
41
+
42
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
51
+
52
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int

    Raises:
        ValueError: if ``b`` is zero.
    """
    # BUG FIX: ``/`` is true division in Python 3 and always yields a float,
    # so the return annotation was corrected from ``int`` to ``float``.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
63
+
64
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
73
+
74
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
    )
    # BUG FIX: the function is annotated ``-> str`` but previously returned a
    # dict; tool output is serialized to text for the LLM, so return the
    # formatted string directly.
    return formatted_search_docs
87
+
88
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query.
    """
    # BUG FIX: ``invoke`` takes the tool input positionally — calling it as
    # ``invoke(query=query)`` raises a TypeError. TavilySearchResults also
    # returns a list of plain dicts (with "url"/"content" keys), not Document
    # objects, so the old ``doc.metadata[...]`` formatting would crash.
    search_results = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{res.get("url", "")}"/>\n{res.get("content", "")}\n</Document>'
        for res in search_results
    )
    return formatted_search_docs
101
+
102
@tool
def arvix_search(query: str) -> str:
    """Search Arxiv for a query and return maximum 3 result.

    Args:
        query: The search query.
    """
    # NOTE: the tool name keeps the original "arvix" spelling because the LLM
    # tool registry (and any prompts referencing it) use that exact name.
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # BUG FIX: ArxivLoader metadata exposes keys like "Published"/"Title"/
    # "Authors" — a "source" key is not guaranteed, so use .get() instead of
    # direct indexing (which risked a KeyError), and return the string to
    # match the ``-> str`` annotation instead of wrapping it in a dict.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
115
+
116
# The system prompt lives in a separate text file so it can be tweaked
# without editing the code.
with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()
119
+
120
# ---------------------------------------------------------------------------
# Retrieval setup: embed previously-answered questions from a CSV into a
# persistent Chroma vector store so the agent can look up known answers.
# ---------------------------------------------------------------------------
CHROMA_DIR = "./chroma_db"
CSV_PATH = "./supabase_docs.csv"
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
_SIMILARITY_THRESHOLD = 0.2  # lower distance means more similar

embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

if os.path.exists(CHROMA_DIR):
    print(f"Loading existing ChromaDB from {CHROMA_DIR}")
    vector_store = Chroma(
        persist_directory=CHROMA_DIR,
        embedding_function=embeddings,
    )
else:
    print(f"Creating new ChromaDB at {CHROMA_DIR}, and loading documents from {CSV_PATH}")
    # BUG FIX: the old code re-checked ``os.path.exists(CHROMA_DIR)`` and
    # called ``shutil.rmtree`` here, but this branch is only reached when the
    # directory does NOT exist — that check was dead code and was removed.
    os.makedirs(CHROMA_DIR)

    if not os.path.exists(CSV_PATH):
        raise FileNotFoundError(f"CSV file {CSV_PATH} does not exist")

    df = pd.read_csv(CSV_PATH)
    documents = []
    for i, row in df.iterrows():
        content = row["content"]

        # Each row stores "<question> Final answer : <answer>"; embed only the
        # question text and stash the answer in metadata.
        question_part = content.split("Final answer :")[0].strip()
        final_answer_part = content.split("Final answer :")[-1].strip() if "Final answer :" in content else ""

        # NOTE(review): the metadata column appears to be single-quoted
        # pseudo-JSON; swapping quotes breaks on apostrophes inside values —
        # consider ``ast.literal_eval`` instead. TODO confirm the column format.
        try:
            metadata = json.loads(row["metadata"].replace("'", '"'))
        except json.JSONDecodeError:
            metadata = {}

        metadata["final_answer"] = final_answer_part

        documents.append(Document(page_content=question_part, metadata=metadata))

    if not documents:
        print("No documents loaded from CSV. ChromaDB will be empty.")

        vector_store = Chroma(
            persist_directory=CHROMA_DIR,
            embedding_function=embeddings
        )
    else:
        vector_store = Chroma.from_documents(
            documents=documents,
            embedding=embeddings,
            persist_directory=CHROMA_DIR,
        )
        vector_store.persist()
        print(f"ChromaDB initialized and persisted with {len(documents)} documents from CSV.")
175
+
176
+
177
# Retriever tool: finds previously-seen questions similar to the current one;
# each match's metadata carries the stored 'final_answer'.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question_Search",
    description=(
        "A tool to retrieve similar questions from a vector store. The retrieved "
        "document's metadata contains the 'final_answer' to the question."
    ),
)

# Agent

# Complete toolbox exposed to the LLM.
tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
    retriever_tool,
]
197
+
198
def build_graph_agent():
    """Build and compile the LangGraph agent.

    NOTE(review): the compiled graph currently contains a single ``retriever``
    node — the LLM ``assistant`` node and the tool bindings below are
    constructed but never added to the graph, so every question is answered
    purely from the vector store. Confirm whether the intended flow is
    assistant -> tools_condition -> ToolNode, and wire it in if so.
    """
    llm = ChatGoogleGenerativeAI(
        # BUG FIX: the constructor parameter is ``model``, not ``model_name``;
        # passing ``model_name`` fails validation in langchain-google-genai.
        model="gemini-1.5-flash",
        temperature=0.0,
    )

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """LLM node: invoke the tool-bound model on the conversation so far."""
        return {
            "messages": [llm_with_tools.invoke(state["messages"])],
        }

    def retriever(state: MessagesState):
        """Answer directly from the vector store.

        Takes the latest message as the query, fetches the closest stored
        question, and returns its recorded final answer as an AIMessage.
        """
        query = state["messages"][-1].content
        similar_docs = vector_store.similarity_search(query, k=3)

        if similar_docs:
            similar_doc = similar_docs[0]
            # Prefer the structured metadata answer; fall back to parsing the
            # "Final answer :" marker out of the raw text.
            if "final_answer" in similar_doc.metadata and similar_doc.metadata["final_answer"]:
                answer = similar_doc.metadata["final_answer"]
            elif "Final answer :" in similar_doc.page_content:
                answer = similar_doc.page_content.split("Final answer :")[-1].strip()
            else:
                answer = similar_doc.page_content.strip()

            return {"messages": [AIMessage(content=answer)]}
        else:
            return {"messages": [AIMessage(content="No similar questions found in the knowledge base.")]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")

    return builder.compile()
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  from dotenv import load_dotenv
6
- from agent import Agent
7
  import pandas as pd
8
  import time
9
 
@@ -16,6 +16,7 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
16
  class BasicAgent:
17
  def __init__(self):
18
  print("BasicAgent initialized.")
 
19
  def __call__(self, question: str) -> str:
20
  print(f"Agent received question (first 50 chars): {question[:50]}...")
21
  fixed_answer = "This is a default answer."
@@ -44,7 +45,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
44
  # 1. Instantiate Agent ( modify this part to create your agent)
45
  try:
46
  load_dotenv()
47
- agent = Agent()
48
  except Exception as e:
49
  print(f"Error instantiating agent: {e}")
50
  return f"Error initializing agent: {e}", None
@@ -90,8 +91,8 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
90
  except Exception as e:
91
  print(f"Error running agent on task {task_id}: {e}")
92
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
93
- # wait 60s before next question
94
- time.sleep(60)
95
 
96
  if not answers_payload:
97
  print("Agent did not produce any answers to submit.")
 
3
  import requests
4
  import inspect
5
  from dotenv import load_dotenv
6
+ from agent import build_graph_agent
7
  import pandas as pd
8
  import time
9
 
 
16
  class BasicAgent:
17
  def __init__(self):
18
  print("BasicAgent initialized.")
19
+ self.graph = build_graph_agent()
20
  def __call__(self, question: str) -> str:
21
  print(f"Agent received question (first 50 chars): {question[:50]}...")
22
  fixed_answer = "This is a default answer."
 
45
  # 1. Instantiate Agent ( modify this part to create your agent)
46
  try:
47
  load_dotenv()
48
+ agent = BasicAgent()
49
  except Exception as e:
50
  print(f"Error instantiating agent: {e}")
51
  return f"Error initializing agent: {e}", None
 
91
  except Exception as e:
92
  print(f"Error running agent on task {task_id}: {e}")
93
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
94
+ # wait 10s before next question
95
+ time.sleep(10)
96
 
97
  if not answers_payload:
98
  print("Agent did not produce any answers to submit.")
metadata.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -5,14 +5,17 @@ description = "Add your description here"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
 
8
  "ddgs>=9.0.0",
9
  "duckduckgo-search>=8.1.1",
10
  "gradio[oauth]>=5.36.2",
11
  "langchain>=0.3.26",
 
12
  "langchain-community>=0.3.27",
13
  "langchain-experimental>=0.3.4",
14
  "langchain-google-genai>=2.1.7",
15
  "langchain-huggingface>=0.3.0",
16
  "langgraph>=0.5.2",
17
  "requests>=2.32.4",
 
18
  ]
 
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "chromadb>=1.0.15",
9
  "ddgs>=9.0.0",
10
  "duckduckgo-search>=8.1.1",
11
  "gradio[oauth]>=5.36.2",
12
  "langchain>=0.3.26",
13
+ "langchain-chroma>=0.2.4",
14
  "langchain-community>=0.3.27",
15
  "langchain-experimental>=0.3.4",
16
  "langchain-google-genai>=2.1.7",
17
  "langchain-huggingface>=0.3.0",
18
  "langgraph>=0.5.2",
19
  "requests>=2.32.4",
20
+ "sentence-transformers>=5.0.0",
21
  ]
supabase_docs.csv ADDED
The diff for this file is too large to render. See raw diff
 
system_prompt.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ You are a helpful assistant tasked with answering questions using a set of tools.
2
+
3
+ Your final answer must strictly follow this format:
4
+ FINAL ANSWER: [ANSWER]
5
+
6
+ Only write the answer in that exact format. Do not explain anything. Do not include any other text.
7
+
8
+ If you are provided with a similar question and its final answer, and the current question is **exactly the same**, then simply return the same final answer without using any tools.
9
+
10
+ Only use tools if the current question is different from the similar one.
11
+
12
+ Examples:
13
+ - FINAL ANSWER: FunkMonk
14
+ - FINAL ANSWER: Paris
15
+ - FINAL ANSWER: 128
16
+
17
+ If you do not follow this format exactly, your response will be considered incorrect.
18
+
uv.lock CHANGED
The diff for this file is too large to render. See raw diff