|
|
|
import math |
|
|
|
import smolagents |
|
|
|
|
|
import os |
|
import re |
|
import requests |
|
import gradio as gr |
|
from langchain_community.chat_models import ChatHuggingFace |
|
from langchain_community.llms import HuggingFaceEndpoint |
|
from langchain_community.tools import DuckDuckGoSearchRun |
|
from langchain_community.utilities import WikipediaAPIWrapper |
|
from langchain_community.vectorstores import Chroma |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain.agents import Tool, AgentExecutor, initialize_agent |
|
from langchain.agents import AgentType |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_core.messages import SystemMessage |
|
from langchain.memory import ConversationBufferWindowMemory |
|
from youtube_transcript_api import YouTubeTranscriptApi |
|
import pytesseract |
|
import cv2 |
|
import pandas as pd |
|
from langchain.tools import tool |
|
|
|
|
|
# Re-export the Hugging Face Hub token only when it is actually configured.
# Assigning os.getenv(...) straight into os.environ raises TypeError when the
# variable is unset, because os.environ values must be str, never None; that
# would abort the whole module at import time.
_hf_hub_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hf_hub_token is not None:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_hub_token
|
# === Strict Formatting Setup ===
# SYSTEM_PROMPT and enforce_final_answer_format() must be live code: the chat
# prompt template and get_agent_response() further down in this file reference
# them. Only the LangChain model setup at the end of this section stays
# disabled.
SYSTEM_PROMPT = """You MUST format responses EXACTLY as:

FINAL ANSWER: [answer]

Rules:
1. Always begin with "FINAL ANSWER:"
2. Never include other text before/after
3. Numbers: Plain (42)
4. Strings: Minimal (Paris)
5. Lists: Comma-separated (5, apple, 10)
6. No markdown or special formatting"""

# Markers after which the model is assumed to have written its answer.
# Compiled once at import time instead of on every call.
_ANSWER_MARKER_RE = re.compile(
    r"(?:FINAL ANSWER:|Answer:|Output:)\s*(.*)", re.IGNORECASE
)


def enforce_final_answer_format(response: str) -> str:
    """Normalise *response* so that it starts with ``FINAL ANSWER:``.

    Args:
        response: Raw agent/LLM output. Non-string values (for example a
            number returned by an agent) are converted with ``str()``;
            ``None`` is treated as empty output.

    Returns:
        The stripped response unchanged if it already starts with
        ``FINAL ANSWER:``. Otherwise ``FINAL ANSWER: <text>``, where ``<text>``
        is the remainder of the line following the first
        ``FINAL ANSWER:`` / ``Answer:`` / ``Output:`` marker (matched
        case-insensitively), or, when no marker is present, the last
        non-empty line. Empty output yields
        ``FINAL ANSWER: No answer generated``.
    """
    text = "" if response is None else str(response)
    text = text.strip()

    # Already formatted correctly.
    if text.startswith("FINAL ANSWER:"):
        return text

    # Try to extract the answer that follows a known marker.
    match = _ANSWER_MARKER_RE.search(text)
    if match:
        return f"FINAL ANSWER: {match.group(1).strip()}"

    # Fall back to the last non-empty line of the output.
    lines = [line.strip() for line in text.split('\n') if line.strip()]
    return f"FINAL ANSWER: {lines[-1] if lines else 'No answer generated'}"


'''
class StrictFormatChatHuggingFace(ChatHuggingFace):
    """Wrapper that enforces FINAL ANSWER format"""
    def _call(self, prompt: str, stop: list = None) -> str:
        response = super()._call(prompt, stop)
        return enforce_final_answer_format(response)

# === LLM Initialization ===
llm = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen1.5-7B-Chat",
    temperature=0.1,
    max_new_tokens=256,
    top_k=10,
    repetition_penalty=1.1,
)

chat_model = StrictFormatChatHuggingFace(llm=llm)
'''
|
|
|
|
|
|
|
@tool
def wikipedia_search(query: str) -> str:
    """Search Wikipedia and return summary."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # A fresh Wikipedia wrapper is created for each call; whatever its run()
    # produces for the query is handed back unchanged.
    wikipedia = WikipediaAPIWrapper()
    return wikipedia.run(query)
|
|
|
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # A new DuckDuckGo search runner is built per call and its result for the
    # query is returned as-is.
    searcher = DuckDuckGoSearchRun()
    return searcher.run(query)
|
|
|
def _extract_video_id(url: str) -> str:
    """Return the YouTube video ID contained in *url*.

    Handles watch URLs with extra query parameters (``...watch?v=ID&t=30s``),
    ``youtu.be/ID`` short links and ``/embed/ID``, ``/shorts/ID``, ``/live/ID``
    and ``/v/ID`` paths. Anything else falls back to the text after the last
    ``v=``, which also passes a bare video ID through unchanged.
    """
    # Function-scope import so the helper does not depend on the file's
    # top-level import block.
    from urllib.parse import parse_qs, urlparse

    cleaned = url.strip()
    # Without a scheme urlparse() puts the host into the path; add one so
    # inputs such as "youtu.be/ID" are parsed the same way as full URLs.
    parsed = urlparse(cleaned if "://" in cleaned else "https://" + cleaned)

    ids_in_query = parse_qs(parsed.query).get("v")
    if ids_in_query:
        return ids_in_query[0]

    segments = [part for part in parsed.path.split("/") if part]
    host = (parsed.hostname or "").lower()
    if segments and (host == "youtu.be" or host.endswith(".youtu.be")):
        return segments[0]
    if len(segments) >= 2 and segments[0] in ("embed", "shorts", "live", "v"):
        return segments[1]

    # Unknown shape: keep the previous behaviour.
    return cleaned.split("v=")[-1]


@tool
def youtube_transcript(url: str) -> str:
    """Extract transcript from a YouTube video URL."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # A plain split on "v=" would keep trailing parameters such as "&t=30s" in
    # the ID and cannot handle short links, so the ID is parsed out properly.
    video_id = _extract_video_id(url)
    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    # Each transcript entry carries its caption under "text"; join them one
    # caption per line.
    return "\n".join(entry["text"] for entry in transcript)
|
|
|
@tool
def image_ocr(path: str) -> str:
    """Extract text from an image file."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    img = cv2.imread(path)
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising. Return that as a tool error string (the same "Error: ..."
    # convention math_calc and python_eval use) instead of passing None on to
    # the OCR call.
    if img is None:
        return f"Error: Could not read image file: {path}"
    return pytesseract.image_to_string(img)
|
|
|
@tool
def read_excel(path: str) -> str:
    """Read contents of an Excel (.xlsx) file."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # Load the workbook with pandas' default read_excel settings and return
    # the resulting DataFrame rendered as plain text.
    return pd.read_excel(path).to_string()
|
|
|
@tool
def reverse_text(text: str) -> str:
    """Reverse the text if it looks reversed."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # The "looks reversed" check is only a heuristic: the flipped string is
    # accepted as the correction whenever it contains a space. Input without
    # any space is returned untouched.
    flipped = "".join(reversed(text))
    if " " not in flipped:
        return text
    return f"Reversed detected. Corrected: {flipped}"
|
|
|
# Example corpus searched by vector_search.
_EXAMPLE_DOCS = (
    "Machine learning involves training algorithms on data.",
    "Neural networks are a part of deep learning.",
    "Supervised learning uses labeled datasets.",
)

# Vector store over _EXAMPLE_DOCS, built on first use and then shared by all
# vector_search calls.
_example_vectordb = None


def _get_example_vectordb():
    """Return the example vector store, building it on the first call."""
    global _example_vectordb
    if _example_vectordb is None:
        # Loading the embedding model and indexing the corpus is the
        # expensive part and does not depend on the query, so it is done once
        # instead of on every search. Rebuilding per call may also re-insert
        # the same documents, depending on how the Chroma backend scopes its
        # default collection.
        embed = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        _example_vectordb = Chroma.from_texts(list(_EXAMPLE_DOCS), embedding=embed)
    return _example_vectordb


@tool
def vector_search(query: str) -> str:
    """Search in example documents using vector similarity."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # Return the two closest example documents, one per line.
    results = _get_example_vectordb().similarity_search(query, k=2)
    return "\n".join(r.page_content for r in results)
|
|
|
@tool
def math_calc(expression: str) -> str:
    """Evaluate a math expression safely."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # Only digits, + - * / . parentheses and spaces may appear in the input.
    permitted = set('0123456789+-*/.() ')
    if not permitted.issuperset(expression):
        return "Error: Invalid characters in expression"

    # NOTE(security): this still calls eval(). The whitelist keeps names and
    # attribute access out, but it admits "**", so an input such as
    # 9**9**9**9 can burn unbounded CPU and memory. Review before exposing
    # this tool to untrusted users.
    try:
        value = eval(expression, {"__builtins__": None}, {})
        return str(value)
    except Exception as err:
        # Any evaluation failure is reported back to the agent as text.
        return f"Error: {err}"
|
|
|
@tool
def python_eval(code: str) -> str:
    """Evaluate basic Python code safely."""
    # The docstring doubles as the @tool description, so it is left verbatim.
    # NOTE(security): removing __builtins__ does NOT make eval() a sandbox.
    # Attribute access on literals (for example ().__class__) can still reach
    # arbitrary classes, and expensive expressions are not bounded. Treat the
    # input as trusted, or replace this with a real sandbox before exposing it.
    restricted_globals = {"__builtins__": None}
    try:
        outcome = eval(code, restricted_globals, {})
        return str(outcome)
    except Exception as err:
        # Any evaluation failure is reported back to the agent as text.
        return f"Error: {err}"
|
|
|
|
|
# Every tool the agent may call. The order is kept exactly as listed.
tools = [
    # Online lookups
    wikipedia_search,
    web_search,
    # Media and file extraction
    youtube_transcript,
    image_ocr,
    read_excel,
    # Text utilities and example-document search
    reverse_text,
    vector_search,
    # Expression evaluation
    math_calc,
    python_eval,
]
|
|
|
# Chat prompt layout: the fixed system instructions first, then the user's
# question, then the slot for the agent's scratchpad. In this file the
# template is only handed to the disabled initialize_agent(...) configuration
# further down.
_prompt_messages = [
    SystemMessage(content=SYSTEM_PROMPT),
    ("human", "{input}"),
    ("ai", "{agent_scratchpad}"),
]
prompt = ChatPromptTemplate.from_messages(_prompt_messages)
|
# The file only does `import smolagents`, so the smolagents classes are
# referenced through the module. The bare names InferenceClientModel and
# CodeAgent are not imported anywhere and would raise NameError at import time.
model = smolagents.InferenceClientModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    provider="together",
)

# The entries in `tools` are LangChain tools (created with langchain's @tool).
# A smolagents agent expects smolagents Tool objects, so each one is wrapped
# with Tool.from_langchain before being handed over. `tools` itself is left
# untouched.
agent = smolagents.CodeAgent(
    model=model,
    tools=[smolagents.Tool.from_langchain(lc_tool) for lc_tool in tools],
    max_steps=20,
)

# Previous LangChain agent configuration, kept disabled for reference.
'''
agent = initialize_agent(
    tools=tools,
    llm=chat_model,
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    memory=ConversationBufferWindowMemory(
        memory_key="chat_history",
        k=3,
        return_messages=True
    ),
    agent_kwargs={
        "system_message": SystemMessage(content=SYSTEM_PROMPT),
        "prompt": prompt
    },
    handle_parsing_errors=True
)
'''
|
|
|
|
|
def get_agent_response(question_text):
    """Run the agent on *question_text* and return a ``FINAL ANSWER:`` string.

    Args:
        question_text: The question typed into the UI.

    Returns:
        The agent's answer passed through ``enforce_final_answer_format``.
        If anything raises while the agent runs or the answer is formatted,
        the error is returned as
        ``FINAL ANSWER: Error processing request: <message>`` instead of
        propagating, so the UI always receives a string.
    """
    try:
        if hasattr(agent, "invoke"):
            # LangChain-style executor: takes a dict and returns a dict whose
            # "output" entry holds the answer.
            answer = agent.invoke({"input": question_text})["output"]
        else:
            # smolagents agents (such as the CodeAgent configured in this
            # file) have no invoke(); they expose run(task) and return the
            # final answer directly. Calling invoke() on them always raised,
            # so every request ended in the error branch below.
            answer = agent.run(question_text)
        # The answer is not guaranteed to be a string (an agent may return a
        # number, for example), while the formatter works on text.
        return enforce_final_answer_format(str(answer))
    except Exception as e:
        return f"FINAL ANSWER: Error processing request: {str(e)}"
|
|
|
# Gradio UI: a question box and a response box side by side, plus a button
# that sends the question through get_agent_response.
with gr.Blocks() as demo:
    gr.Markdown("# Strict Format Agent")
    gr.Markdown("This agent enforces FINAL ANSWER: format for all responses")

    # Input and output text boxes share one row.
    with gr.Row():
        question = gr.Textbox(label="Your Question")
        output = gr.Textbox(label="Agent Response")

    # Clicking the button calls get_agent_response(question) and writes the
    # returned string into the response box.
    submit = gr.Button("Submit")
    submit.click(fn=get_agent_response, inputs=question, outputs=output)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()