File size: 6,464 Bytes
93fe2e2
ba1447c
636adbc
e643502
 
 
cd26f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556300c
cd26f3e
 
 
636adbc
cd26f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636adbc
 
cd26f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636adbc
 
 
457125a
 
636adbc
 
cd26f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636adbc
cd26f3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213

import math
#from typing import Optional, Tuple
import smolagents
#from smolagents import tool
#import smolagents[litellm]
import os
import re
import requests
import gradio as gr
from langchain_community.chat_models import ChatHuggingFace
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.agents import Tool, AgentExecutor, initialize_agent
from langchain.agents import AgentType
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage
from langchain.memory import ConversationBufferWindowMemory
from youtube_transcript_api import YouTubeTranscriptApi
import pytesseract
import cv2
import pandas as pd
from langchain.tools import tool

# === Configuration ===
# Re-export the Hugging Face token only when it is actually configured.
# os.environ accepts only str values, so assigning the None that os.getenv
# returns for an unset variable raises TypeError and aborts the import.
_hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if _hf_token is not None:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = _hf_token
# === Strict Formatting Setup ===
# SYSTEM_PROMPT and enforce_final_answer_format are referenced by the agent
# setup and the Gradio handler further down, so they must be live code rather
# than part of the disabled string block.
SYSTEM_PROMPT = """You MUST format responses EXACTLY as:

FINAL ANSWER: [answer]

Rules:
1. Always begin with "FINAL ANSWER:"
2. Never include other text before/after
3. Numbers: Plain (42)
4. Strings: Minimal (Paris)
5. Lists: Comma-separated (5, apple, 10)
6. No markdown or special formatting"""


def enforce_final_answer_format(response: str) -> str:
    """Normalize a model response to the ``FINAL ANSWER: <answer>`` shape.

    Args:
        response: Raw text produced by the model/agent.

    Returns:
        The stripped response unchanged when it already starts with
        ``FINAL ANSWER:``; otherwise the text following the first
        ``FINAL ANSWER:``/``Answer:``/``Output:`` marker (case-insensitive,
        up to the end of that line); otherwise the last non-empty line.
        ``FINAL ANSWER: No answer generated`` is returned for blank input.
    """
    response = response.strip()

    # Already formatted correctly
    if response.startswith("FINAL ANSWER:"):
        return response

    # Try to extract answer from LLM output ('.' stops at the line end)
    match = re.search(r"(?:FINAL ANSWER:|Answer:|Output:)\s*(.*)", response, re.IGNORECASE)
    if match:
        return f"FINAL ANSWER: {match.group(1).strip()}"

    # Fallback to last non-empty line
    lines = [line.strip() for line in response.split('\n') if line.strip()]
    return f"FINAL ANSWER: {lines[-1] if lines else 'No answer generated'}"


# The LangChain chat-model wrapper and HuggingFaceEndpoint setup stay disabled:
# instantiating them contacts the Hugging Face endpoint at import time.
'''
class StrictFormatChatHuggingFace(ChatHuggingFace):
    """Wrapper that enforces FINAL ANSWER format"""
    def _call(self, prompt: str, stop: list = None) -> str:
        response = super()._call(prompt, stop)
        return enforce_final_answer_format(response)

# === LLM Initialization ===
llm = HuggingFaceEndpoint(
    repo_id="Qwen/Qwen1.5-7B-Chat",
    temperature=0.1,
    max_new_tokens=256,
    top_k=10,
    repetition_penalty=1.1,
)

chat_model = StrictFormatChatHuggingFace(llm=llm)
'''


# === Tools Setup ===
@tool
def wikipedia_search(query: str) -> str:
    """Search Wikipedia and return summary."""
    # The docstring is kept verbatim: LangChain's @tool exposes it to the
    # agent as the tool description.
    wiki = WikipediaAPIWrapper()
    summary = wiki.run(query)
    return summary

@tool
def web_search(query: str) -> str:
    """Search the web using DuckDuckGo."""
    # A fresh DuckDuckGo runner is created per call, as before; the query is
    # forwarded unchanged and the runner's text result is returned.
    searcher = DuckDuckGoSearchRun()
    hits = searcher.run(query)
    return hits

@tool
def youtube_transcript(url: str) -> str:
    """Extract transcript from a YouTube video URL."""
    # Accepted inputs: watch URLs (``...watch?v=ID&t=10s``), short links
    # (``youtu.be/ID``), ``/embed/ID``, ``/shorts/ID``, ``/live/ID`` and a
    # bare video id. Returns the transcript snippets joined by newlines.
    from urllib.parse import parse_qs, urlparse

    cleaned = url.strip()
    # urlparse only recognizes the host when a scheme is present.
    if "://" not in cleaned and "." in cleaned.partition("/")[0]:
        cleaned = "https://" + cleaned

    parts = urlparse(cleaned)
    query_ids = parse_qs(parts.query).get("v")
    segments = [seg for seg in parts.path.split("/") if seg]

    if query_ids:
        # Proper query parsing drops trailing parameters such as "&t=10s",
        # which the plain split on "v=" used to keep as part of the id.
        video_id = query_ids[0]
    elif parts.netloc.lower().endswith("youtu.be") and segments:
        video_id = segments[0]
    elif len(segments) >= 2 and segments[0] in ("embed", "shorts", "live", "v"):
        video_id = segments[1]
    else:
        # Unrecognized shape (e.g. a bare id): keep the original behavior.
        video_id = url.strip().split("v=")[-1]

    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    return "\n".join(x["text"] for x in transcript)

@tool
def image_ocr(path: str) -> str:
    """Extract text from an image file."""
    img = cv2.imread(path)
    # cv2.imread signals a missing or undecodable file by returning None
    # instead of raising; report that clearly rather than letting
    # pytesseract fail on a None image.
    if img is None:
        return f"Error: Could not read image at '{path}'"
    # OpenCV loads pixels in BGR order while pytesseract assumes RGB.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return pytesseract.image_to_string(rgb)

@tool
def read_excel(path: str) -> str:
    """Read contents of an Excel (.xlsx) file."""
    # Load the workbook with pandas' defaults and render the whole frame as
    # plain text so the agent can read it.
    return pd.read_excel(path).to_string()

@tool
def reverse_text(text: str) -> str:
    """Reverse the text if it looks reversed."""
    flipped = "".join(reversed(text))
    # The only "looks reversed" signal used here is the presence of a space;
    # input without any space is handed back untouched.
    if " " not in flipped:
        return text
    return f"Reversed detected. Corrected: {flipped}"

# Fixed example corpus indexed by vector_search.
_EXAMPLE_DOCS = [
    "Machine learning involves training algorithms on data.",
    "Neural networks are a part of deep learning.",
    "Supervised learning uses labeled datasets."
]

# Lazily-built vector store shared by all vector_search calls.
_example_vectordb = None


def _get_example_vectordb():
    """Build the example vector store on first use and reuse it afterwards."""
    global _example_vectordb
    if _example_vectordb is None:
        embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        _example_vectordb = Chroma.from_texts(_EXAMPLE_DOCS, embedding=embed)
    return _example_vectordb


@tool
def vector_search(query: str) -> str:
    """Search in example documents using vector similarity."""
    # The corpus is static, so the embedding model and index are created once
    # instead of on every query. NOTE(review): rebuilding per call presumably
    # also re-added the same texts to Chroma's default collection, which
    # could yield duplicate hits -- confirm against the Chroma client in use.
    results = _get_example_vectordb().similarity_search(query, k=2)
    return "\n".join(r.page_content for r in results)

@tool
def math_calc(expression: str) -> str:
    """Evaluate a math expression safely."""
    # Supports numbers, parentheses, unary +/- and the binary operators
    # + - * / // **. Returns the result as text, or an "Error: ..." message.
    #
    # SECURITY: this used to call eval() behind a character whitelist. The
    # whitelist still allowed inputs such as "9**9**9**9" (unbounded CPU and
    # memory) and non-arithmetic forms like "..." or "()". The expression is
    # now parsed with ast and only arithmetic nodes are evaluated.
    import ast
    import operator

    allowed_chars = set('0123456789+-*/.() ')
    if not all(c in allowed_chars for c in expression):
        return "Error: Invalid characters in expression"

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    # Upper bound on the size of an integer power result (roughly 30k digits).
    max_result_bits = 100_000

    def evaluate(node):
        """Recursively evaluate an arithmetic-only AST node."""
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > max_result_bits
            ):
                raise ValueError("Result too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError("Unsupported expression")

    try:
        # eval() ignored surrounding whitespace; ast.parse rejects a leading
        # space as an unexpected indent, so strip explicitly.
        tree = ast.parse(expression.strip(), mode="eval")
        return str(evaluate(tree))
    except Exception as e:
        return f"Error: {str(e)}"

@tool
def python_eval(code: str) -> str:
    """Evaluate basic Python code safely."""
    # NOTE(security): eval() on model-supplied text is not a real sandbox.
    # Removing __builtins__ does not block attribute-based escapes such as
    # ().__class__.__mro__ -- review before exposing this to untrusted input.
    restricted_globals = {"__builtins__": None}
    try:
        value = eval(code, restricted_globals, {})
        rendered = str(value)
    except Exception as exc:
        return f"Error: {exc}"
    return rendered

# === Agent Initialization ===
# LangChain tool objects produced by the @tool decorator above.
tools = [
    wikipedia_search,
    web_search,
    youtube_transcript,
    image_ocr,
    read_excel,
    reverse_text,
    vector_search,
    math_calc,
    python_eval
]

# Prompt for the LangChain executor variant kept (disabled) at the end of this
# section; the smolagents CodeAgent below does not consume it.
prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=SYSTEM_PROMPT),
    ("human", "{input}"),
    ("ai", "{agent_scratchpad}")
])

# The file only does `import smolagents`, so the classes must be reached
# through the module; the bare names InferenceClientModel / CodeAgent were
# never imported and raised NameError.
model = smolagents.InferenceClientModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    provider="together",
)
agent = smolagents.CodeAgent(
    model=model,
    # smolagents agents expect smolagents Tool instances, so each LangChain
    # tool is wrapped with the adapter smolagents provides for that purpose.
    tools=[smolagents.Tool.from_langchain(lc_tool) for lc_tool in tools],
    max_steps=20
)
'''
agent = initialize_agent(
    tools=tools,
    llm=chat_model,
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
    memory=ConversationBufferWindowMemory(
        memory_key="chat_history",
        k=3,
        return_messages=True
    ),
    agent_kwargs={
        "system_message": SystemMessage(content=SYSTEM_PROMPT),
        "prompt": prompt
    },
    handle_parsing_errors=True
)
'''

# === Gradio Interface ===
def get_agent_response(question_text):
    """Run the agent on a question and return a ``FINAL ANSWER:`` string.

    Args:
        question_text: The user's question from the Gradio textbox.

    Returns:
        The agent's answer normalized by ``enforce_final_answer_format``, or
        a ``FINAL ANSWER: Error processing request: ...`` message if anything
        raises while the agent is running.
    """
    try:
        if hasattr(agent, "invoke"):
            # LangChain-style executor: dict in, dict with "output" out.
            answer = agent.invoke({"input": question_text})["output"]
        else:
            # smolagents agents expose run(task) and return the final answer
            # directly; calling .invoke on them raised AttributeError, so
            # every request ended in the error branch below.
            answer = agent.run(question_text)
        # The final answer is not guaranteed to be a string (e.g. a number).
        return enforce_final_answer_format(str(answer))
    except Exception as e:
        # UI boundary: surface the failure in the response box.
        return f"FINAL ANSWER: Error processing request: {str(e)}"

# Components are created in display order inside the Blocks context.
with gr.Blocks() as demo:
    # Page header
    gr.Markdown(value="# Strict Format Agent")
    gr.Markdown(value="This agent enforces FINAL ANSWER: format for all responses")

    # Question and answer boxes share one row
    with gr.Row():
        question = gr.Textbox(label="Your Question")
        output = gr.Textbox(label="Agent Response")

    # Clicking the button sends the question to the agent and shows the reply
    submit = gr.Button(value="Submit")
    submit.click(get_agent_response, question, output)

# Start the web UI only when the file is run as a script
if __name__ == "__main__":
    demo.launch()