import json
import pickle

import requests
from io import BytesIO
from pathlib import Path

from cachetools import TTLCache
from dotenv import load_dotenv
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEmbeddings, HuggingFaceEndpoint
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

load_dotenv()
# ----------------------------------------------------------
# 0. Constants
# ----------------------------------------------------------
JSONL_PATH = Path("metadata.jsonl")
FAISS_CACHE = Path("faiss_index.pkl")
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
RETRIEVER_K = 5
CACHE_TTL = 600  # seconds
CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
# ----------------------------------------------------------
# 1. Build / load FAISS retriever
# ----------------------------------------------------------
embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

if FAISS_CACHE.exists():
    with open(FAISS_CACHE, "rb") as f:
        vector_store = pickle.load(f)
else:
    if not JSONL_PATH.exists():
        raise FileNotFoundError("metadata.jsonl not found")
    docs = []
    with open(JSONL_PATH, "rt", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            content = f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}"
            docs.append(Document(page_content=content, metadata={"source": rec["task_id"]}))
    vector_store = FAISS.from_documents(docs, embeddings)
    with open(FAISS_CACHE, "wb") as f:
        pickle.dump(vector_store, f)

retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
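
# Quick sanity check (illustrative only; assumes the index above built correctly):
#   retriever.invoke("What is the capital of France?")  # -> list of up to RETRIEVER_K Documents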
# ----------------------------------------------------------
# 2. Caching helper
# ----------------------------------------------------------
def cached_get(key: str, fetch_fn):
    """Return CACHE[key] if present; otherwise call fetch_fn, store, and return."""
    if key in CACHE:
        return CACHE[key]
    val = fetch_fn()
    CACHE[key] = val
    return val
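
# Usage sketch (expensive_search is a hypothetical fetcher; entries expire CACHE_TTL
# seconds after insertion, per cachetools TTLCache semantics):
#   cached_get("web:python", lambda: expensive_search("python"))  # first call fetches
#   cached_get("web:python", lambda: expensive_search("python"))  # second call hits CACHE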
# ----------------------------------------------------------
# 3. Tools
# ----------------------------------------------------------
@tool
def python_repl(code: str) -> str:
    """Execute Python code and return stdout/stderr."""
    import subprocess
    import textwrap

    code = textwrap.dedent(code).strip()
    try:
        result = subprocess.run(
            ["python", "-c", code],
            capture_output=True,
            text=True,
            timeout=5,
        )
        return result.stdout if not result.stderr else f"STDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
    except subprocess.TimeoutExpired:
        return "Execution timed out (>5s)."
@tool
def describe_image(image_source: str) -> str:
    """Describe an image from local path or URL with Gemini vision."""
    import base64

    from PIL import Image

    if image_source.startswith("http"):
        img = Image.open(BytesIO(requests.get(image_source, timeout=10).content))
    else:
        img = Image.open(image_source)
    buffered = BytesIO()
    img.convert("RGB").save(buffered, format="JPEG")
    b64 = base64.b64encode(buffered.getvalue()).decode()
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    msg = HumanMessage(
        content=[
            {"type": "text", "text": "Describe this image in detail."},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64}"}},
        ]
    )
    return llm.invoke([msg]).content
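
# Illustrative calls (hypothetical sources; requires a Google API key in the environment):
#   describe_image.invoke({"image_source": "https://example.com/cat.jpg"})
#   describe_image.invoke({"image_source": "local_photo.png"})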
@tool
def web_search(query: str) -> str:
    """Smart web search with 3 keyword variants, cached."""
    from langchain_community.tools.tavily_search import TavilySearchResults

    keywords = [query, query.replace(" ", " OR "), f'"{query}"']
    seen = set()
    results = []
    for kw in keywords:
        key = f"web:{kw}"
        snippets = cached_get(
            key,
            lambda kw=kw: TavilySearchResults(max_results=3, include_raw_content=True).invoke(kw),
        )
        for s in snippets:
            if s["url"] not in seen:
                seen.add(s["url"])
                results.append(s["content"][:2000])
            if len(results) >= 5:
                break
        if len(results) >= 5:  # also stop issuing further keyword variants once we have enough
            break
    return "\n\n---\n\n".join(results)
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia and return up to 2 matching documents."""
    from langchain_community.document_loaders import WikipediaLoader

    key = f"wiki:{query}"
    docs = cached_get(
        key,
        lambda: WikipediaLoader(query=query, load_max_docs=2).load(),
    )
    return "\n\n---\n\n".join(
        f'<Document source="{d.metadata.get("source", "")}">\n{d.page_content}\n</Document>'
        for d in docs
    )
@tool
def arxiv_search(query: str) -> str:
    """Search arXiv and return up to 2 matching papers (content truncated)."""
    from langchain_community.document_loaders import ArxivLoader

    key = f"arxiv:{query}"
    docs = cached_get(
        key,
        lambda: ArxivLoader(query=query, load_max_docs=2).load(),
    )
    return "\n\n---\n\n".join(
        # ArxivLoader exposes "Title" (not "source") in document metadata
        f'<Document source="{d.metadata.get("Title", "")}">\n{d.page_content[:2000]}...\n</Document>'
        for d in docs
    )
# ----------------------------------------------------------
# 4. System prompt
# ----------------------------------------------------------
SYSTEM_PROMPT = (
    """You are a helpful assistant tasked with answering questions using a set of tools.
Your final answer must strictly follow this format:
FINAL ANSWER: [ANSWER]
Only write the answer in that exact format. Do not explain anything. Do not include any other text.
If you are provided with a similar question and its final answer, and the current question is **exactly the same**, then simply return the same final answer without using any tools.
Only use tools if the current question is different from the similar one.
Examples:
- FINAL ANSWER: FunkMonk
- FINAL ANSWER: Paris
- FINAL ANSWER: 128
If you do not follow this format exactly, your response will be considered incorrect.
"""
)
# ----------------------------------------------------------
# 5. Manual LangGraph construction
# ----------------------------------------------------------
tools_list = [python_repl, describe_image, web_search, wiki_search, arxiv_search]

# retriever tool
from langchain.tools.retriever import create_retriever_tool

tools_list.append(
    create_retriever_tool(
        retriever=retriever,
        name="retrieve_examples",
        description="Retrieve up to 5 solved questions similar to the user query.",
    )
)
# ----------------------------------------------------------
# provider switcher
# ----------------------------------------------------------
def build_llm(provider: str = "groq"):
    if provider == "google":
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
    elif provider == "groq":
        return ChatGroq(model="llama-3.3-70b-versatile", temperature=0)
    elif provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
                temperature=0,
            )
        )
    else:
        raise ValueError("provider must be 'google', 'groq', or 'huggingface'")


llm = build_llm("google")  # or "groq", "huggingface"
llm_with_tools = llm.bind_tools(tools_list)


def assistant(state: MessagesState):
    """LLM node that can call tools."""
    return {"messages": [llm_with_tools.invoke(state["messages"])]}


def retriever_node(state: MessagesState):
    """First node: fetch examples and prepend them."""
    user_query = state["messages"][-1].content
    docs = retriever.invoke(user_query)
    system_msg = SystemMessage(content=SYSTEM_PROMPT)  # wrap the raw prompt string in a proper message
    if docs:
        example_text = "\n\n---\n\n".join(d.page_content for d in docs)
        example_msg = HumanMessage(
            content=f"Here are {len(docs)} similar solved examples:\n\n{example_text}"
        )
        return {"messages": [system_msg] + state["messages"] + [example_msg]}
    return {"messages": [system_msg] + state["messages"]}
builder = StateGraph(MessagesState)
builder.add_node("retriever", retriever_node)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools_list))
builder.add_edge(START, "retriever")
builder.add_edge("retriever", "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
agent = builder.compile()
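
# Single-shot invocation sketch (alternative to the streaming test below; the exact
# answer depends on the model and tools):
#   result = agent.invoke({"messages": [("user", "What is 12 * 4?")]})
#   print(result["messages"][-1].content)  # expected form: "FINAL ANSWER: 48"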
# ----------------------------------------------------------
# 6. Quick streaming test
# ----------------------------------------------------------
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    print("Agent thinking …")
    # stream_mode="values" yields the full state after each node, so print each
    # latest message on its own line rather than concatenating token-style
    for chunk in agent.stream({"messages": [("user", question)]}, stream_mode="values"):
        last = chunk["messages"][-1]
        if hasattr(last, "content"):
            print(last.content, flush=True)