# medmiu's picture
# Rename app.py to app4.py
# 03ea85b verified
import math
#from typing import Optional, Tuple
import smolagents
#from smolagents import tool
#import smolagents[litellm]
import os
import re
import requests
import gradio as gr
from langchain_community.chat_models import ChatHuggingFace
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.agents import Tool, AgentExecutor, initialize_agent
from langchain.agents import AgentType
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage
from langchain.memory import ConversationBufferWindowMemory
from youtube_transcript_api import YouTubeTranscriptApi
import pytesseract
import cv2
import pandas as pd
from langchain.tools import tool
# === Configuration ===
# Re-export the Hugging Face token only when it is actually configured.
# os.environ values must be strings, so assigning the None that os.getenv
# returns for an unset variable raises TypeError while the module is imported.
_hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hf_token is not None:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_token
# === Strict Formatting Setup ===
# Instruction text describing the exact reply format expected from the model.
# Live module code reads this name, so it must be a real definition rather
# than part of the disabled block at the end of this section.
SYSTEM_PROMPT = """You MUST format responses EXACTLY as:
FINAL ANSWER: [answer]
Rules:
1. Always begin with "FINAL ANSWER:"
2. Never include other text before/after
3. Numbers: Plain (42)
4. Strings: Minimal (Paris)
5. Lists: Comma-separated (5, apple, 10)
6. No markdown or special formatting"""

# Markers that may introduce the answer inside free-form model output
# (matched case-insensitively; the capture stops at the end of the line).
_ANSWER_MARKER_RE = re.compile(
    r"(?:FINAL ANSWER:|Answer:|Output:)\s*(.*)", re.IGNORECASE
)


def enforce_final_answer_format(response: str) -> str:
    """Ensure *response* follows the ``FINAL ANSWER: <answer>`` format.

    Resolution order:

    1. A response that already starts with ``FINAL ANSWER:`` (after
       stripping surrounding whitespace) is returned as is.
    2. Otherwise the text following the first ``FINAL ANSWER:``,
       ``Answer:`` or ``Output:`` marker is used as the answer.
    3. Otherwise the last non-empty line is used, or the placeholder
       ``No answer generated`` when there is no text at all.

    Args:
        response: Raw text produced by the model or agent.

    Returns:
        A string beginning with ``FINAL ANSWER:``.
    """
    response = response.strip()
    # Already formatted correctly
    if response.startswith("FINAL ANSWER:"):
        return response
    # Try to extract the answer from the LLM output
    match = _ANSWER_MARKER_RE.search(response)
    if match:
        return f"FINAL ANSWER: {match.group(1).strip()}"
    # Fall back to the last non-empty line
    lines = [line.strip() for line in response.split('\n') if line.strip()]
    return f"FINAL ANSWER: {lines[-1] if lines else 'No answer generated'}"


# The LangChain chat-model setup below stays disabled (kept as an inert string
# literal): it would create a hosted-inference endpoint at import time.
'''
class StrictFormatChatHuggingFace(ChatHuggingFace):
    """Wrapper that enforces FINAL ANSWER format"""
    def _call(self, prompt: str, stop: list = None) -> str:
        response = super()._call(prompt, stop)
        return enforce_final_answer_format(response)
# === LLM Initialization ===
llm = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen1.5-7B-Chat",
    temperature=0.1,
    max_new_tokens=256,
    top_k=10,
    repetition_penalty=1.1,
)
chat_model = StrictFormatChatHuggingFace(llm=llm)
'''
# === Tools Setup ===
@tool
def wikipedia_search(query: str) -> str:
    """Search Wikipedia and return summary."""
    # A fresh API wrapper is created for each call and asked to run the query;
    # whatever text it produces is handed back unchanged.
    wiki = WikipediaAPIWrapper()
    return wiki.run(query)
@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo."""
    # A new search runner is built per call; its result text is returned as is.
    searcher = DuckDuckGoSearchRun()
    return searcher.run(query)
def _extract_video_id(url: str) -> str:
    """Return the YouTube video ID contained in *url*.

    Handles ``watch?v=<id>`` links (with any extra query parameters),
    ``youtu.be/<id>`` short links and ``/embed/<id>``, ``/shorts/<id>``,
    ``/live/<id>`` and ``/v/<id>`` paths. Anything else falls back to the
    text after the last ``v=``, so a bare video ID is returned unchanged.
    """
    from urllib.parse import parse_qs, urlparse

    parts = urlparse(url.strip())
    host = (parts.hostname or "").lower()
    segments = [segment for segment in parts.path.split("/") if segment]

    # Short links carry the ID as the first path segment.
    if (host == "youtu.be" or host.endswith(".youtu.be")) and segments:
        return segments[0]
    # Player-style paths carry the ID right after the path prefix.
    if len(segments) >= 2 and segments[0] in {"embed", "shorts", "live", "v"}:
        return segments[1]
    # Regular watch links: take only the v parameter, ignoring &t=, &list=, ...
    ids = parse_qs(parts.query).get("v")
    if ids:
        return ids[0]
    # Legacy behavior for inputs that are not recognizable URLs.
    return url.split("v=")[-1]


@tool
def youtube_transcript(url: str) -> str:
    """Extract transcript from a YouTube video URL."""
    # Splitting on "v=" alone left trailing parameters (e.g. "&t=10s") glued to
    # the ID and could not handle short or embed links, so the ID is parsed
    # properly before the transcript is requested.
    video_id = _extract_video_id(url)
    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    # One transcript entry per line.
    return "\n".join(entry["text"] for entry in transcript)
@tool
def image_ocr(path: str) -> str:
    """Extract text from an image file."""
    img = cv2.imread(path)
    # cv2.imread reports a missing or undecodable file by returning None rather
    # than raising; report that clearly instead of passing None on to the OCR
    # engine, which fails with an unrelated-looking error.
    if img is None:
        return f"Error: Could not read image file: {path}"
    return pytesseract.image_to_string(img)
@tool
def read_excel(path: str) -> str:
    """Read contents of an Excel (.xlsx) file."""
    # Load the workbook with pandas' defaults and render the resulting frame
    # as plain text in a single expression.
    return pd.read_excel(path).to_string()
@tool
def reverse_text(text: str) -> str:
    """Reverse the text if it looks reversed."""
    # Heuristic: only input containing a space is treated as reversed text.
    # Reversing a string never adds or removes spaces, so the check can be
    # made on the input before flipping it.
    if " " not in text:
        return text
    flipped = text[::-1]
    return f"Reversed detected. Corrected: {flipped}"
# Vector store shared by every vector_search call; built on first use.
_VECTOR_SEARCH_DB = None


@tool
def vector_search(query: str) -> str:
    """Search in example documents using vector similarity."""
    global _VECTOR_SEARCH_DB
    if _VECTOR_SEARCH_DB is None:
        # Build the index once instead of on every call: the corpus is fixed,
        # so reloading the embedding model and re-embedding it per query is
        # pure overhead.
        # NOTE(review): repeated Chroma.from_texts calls in one process may
        # also pile duplicate documents into the default collection - confirm.
        docs = [
            "Machine learning involves training algorithms on data.",
            "Neural networks are a part of deep learning.",
            "Supervised learning uses labeled datasets.",
        ]
        embed = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        _VECTOR_SEARCH_DB = Chroma.from_texts(docs, embedding=embed)
    # Return the two closest documents, one per line.
    results = _VECTOR_SEARCH_DB.similarity_search(query, k=2)
    return "\n".join(r.page_content for r in results)
@tool
def math_calc(expression: str) -> str:
    """Evaluate a math expression safely."""
    # Only digits, the four arithmetic operators, dots, parentheses and spaces
    # may appear; the first character outside that set rejects the input.
    permitted = frozenset("0123456789+-*/.() ")
    for ch in expression:
        if ch not in permitted:
            return "Error: Invalid characters in expression"
    # NOTE(review): the character filter keeps names and attribute access out
    # of eval, but "**" chains (e.g. 9**9**9**9) still pass and can consume
    # unbounded CPU/memory - consider an ast-based evaluator.
    try:
        # The string conversion stays inside the try so that a failure while
        # rendering the result is reported the same way as an evaluation error.
        outcome = eval(expression, {"__builtins__": None}, {})
        return str(outcome)
    except Exception as err:
        return f"Error: {str(err)}"
@tool
def python_eval(code: str) -> str:
    """Evaluate basic Python code safely."""
    # NOTE(review): eval on model-supplied text. Removing __builtins__ is not a
    # real sandbox (object introspection can still reach arbitrary classes);
    # treat this tool as unsafe for untrusted input.
    empty_globals = {"__builtins__": None}
    try:
        # Evaluate as a single expression with no builtins and no locals, and
        # hand the result back as text.
        outcome = eval(code, empty_globals, {})
        return str(outcome)
    except Exception as err:
        return f"Error: {str(err)}"
# === Agent Initialization ===
# Tools offered to the agent; all are LangChain tools created with @tool.
tools = [
    wikipedia_search,
    web_search,
    youtube_transcript,
    image_ocr,
    read_excel,
    reverse_text,
    vector_search,
    math_calc,
    python_eval,
]

# LangChain-style chat prompt. The smolagents agent below does not consume it;
# it is only referenced by the disabled LangChain agent at the end of this
# section. SYSTEM_PROMPT must be defined earlier in the module.
prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=SYSTEM_PROMPT),
    ("human", "{input}"),
    ("ai", "{agent_scratchpad}"),
])

# The file only does `import smolagents`, so the classes are reached through
# the module; the bare names InferenceClientModel / CodeAgent were undefined.
model = smolagents.InferenceClientModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    provider="together",
)
agent = smolagents.CodeAgent(
    model=model,
    # smolagents agents work with smolagents Tool objects, so each LangChain
    # tool is wrapped before being handed over.
    tools=[smolagents.Tool.from_langchain(lc_tool) for lc_tool in tools],
    max_steps=20,
)

# Disabled LangChain agent, kept for reference as an inert string literal.
# It depends on `chat_model`, which is not defined while the LangChain LLM
# setup is disabled.
'''
agent = initialize_agent(
    tools=tools,
    llm=chat_model,
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    memory=ConversationBufferWindowMemory(
        memory_key="chat_history",
        k=3,
        return_messages=True
    ),
    agent_kwargs={
        "system_message": SystemMessage(content=SYSTEM_PROMPT),
        "prompt": prompt
    },
    handle_parsing_errors=True
)
'''
# === Gradio Interface ===
def get_agent_response(question_text):
    """Run the agent on one question and return a ``FINAL ANSWER:`` string.

    Args:
        question_text: The question typed into the UI.

    Returns:
        The agent's answer normalized by ``enforce_final_answer_format``, or
        a ``FINAL ANSWER: Error processing request: ...`` message when the
        agent raises.
    """
    try:
        if hasattr(agent, "invoke"):
            # LangChain-style executor: dict in, dict with "output" out.
            result = agent.invoke({"input": question_text})["output"]
        else:
            # smolagents agents expose run(task) and return the final answer
            # directly (not necessarily a string). Calling invoke() on them
            # failed on every request and always produced the error message.
            result = agent.run(question_text)
        return enforce_final_answer_format(str(result))
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"FINAL ANSWER: Error processing request: {str(e)}"
# Page layout: a title and description, one row holding the question and
# answer text boxes, and a button that sends the question to the agent.
with gr.Blocks() as demo:
    gr.Markdown("# Strict Format Agent")
    gr.Markdown("This agent enforces FINAL ANSWER: format for all responses")
    with gr.Row():
        question = gr.Textbox(label="Your Question")
        output = gr.Textbox(label="Agent Response")
    submit = gr.Button(value="Submit")
    # Clicking the button passes the question text to get_agent_response and
    # shows its return value in the answer box.
    submit.click(fn=get_agent_response, inputs=question, outputs=output)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()