Update agent.py

agent.py CHANGED
@@ -1,108 +1,261 @@
 import os
-import subprocess
 import random
 import gradio as gr
-from
-from
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
     return prompt
-def run_agent(instruction, history):
     prompt = format_prompt_var(instruction, history)
     response = ""
-        response += chunk
-        if "\n\n### Instruction:" in chunk:
-            break
     return response, response_actions
-def generate(prompt, history, temperature):
     generate_kwargs = {
         "temperature": temperature,
-        "max_new_tokens":
         "top_p": 0.95,
         "repetition_penalty": 1.0,
         "do_sample": True,
         "seed": seed,
     }
-    formatted_prompt = format_prompt_var(
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
     for response in stream:
         yield output
-def
     chatbot = gr.Chatbot()
     txt = gr.Textbox(show_label=False, placeholder="Type something...")
     btn = gr.Button("Send", variant="primary")
-    txt.
-    with
     return block
-def _clear_history(history):
-    return [], []
-# Exportable functions and variables
-__all__ = [
-    "run_agent",
-    "create_interface",
-    "format_prompt_var",
-    "generate",
-    # "MAX_HISTORY",
-    "client",
-    "VERBOSE",
-    "date_time_str",
-]
+
 import os
 import random
+import logging
 import gradio as gr
+import asyncio
+from typing import List, Tuple, Generator, Any, Optional
+from huggingface_hub import InferenceClient  # hosted text-generation API
+
+# Set up logging to capture errors and warnings.
+logging.basicConfig(
+    level=logging.INFO,
+    filename='chatbot.log',
+    format='%(asctime)s - %(levelname)s - %(message)s'
+)
+
+# Encapsulated configuration to avoid global variable pitfalls.
+class ChatbotConfig:
+    def __init__(
+        self,
+        max_history: int = 100,
+        verbose: bool = True,
+        max_iterations: int = 1000,
+        max_new_tokens: int = 256,
+        default_seed: Optional[int] = None
+    ):
+        self.max_history = max_history
+        self.verbose = verbose
+        self.max_iterations = max_iterations
+        self.max_new_tokens = max_new_tokens
+        self.default_seed = default_seed or random.randint(1, 2**32 - 1)
+
+# Global configuration instance.
+config = ChatbotConfig()
+
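For local experimentation the defaults can be overridden at construction time; a minimal sketch with illustrative values:

    config = ChatbotConfig(max_history=50, max_new_tokens=512, default_seed=42)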
+# Externalize prompts into a dictionary, optionally overridden by environment variables.
+PROMPTS = {
+    "ACTION_PROMPT": os.environ.get("ACTION_PROMPT", "action prompt"),
+    "ADD_PROMPT": os.environ.get("ADD_PROMPT", "add prompt"),
+    "COMPRESS_HISTORY_PROMPT": os.environ.get("COMPRESS_HISTORY_PROMPT", "compress history prompt"),
+    "LOG_PROMPT": os.environ.get("LOG_PROMPT", "log prompt"),
+    "LOG_RESPONSE": os.environ.get("LOG_RESPONSE", "log response"),
+    "MODIFY_PROMPT": os.environ.get("MODIFY_PROMPT", "modify prompt"),
+    "PREFIX": os.environ.get("PREFIX", "prefix"),
+    "SEARCH_QUERY": os.environ.get("SEARCH_QUERY", "search query"),
+    "READ_PROMPT": os.environ.get("READ_PROMPT", "read prompt"),
+    "TASK_PROMPT": os.environ.get("TASK_PROMPT", "task prompt"),
+    "UNDERSTAND_TEST_RESULTS_PROMPT": os.environ.get("UNDERSTAND_TEST_RESULTS_PROMPT", "understand test results prompt")
+}
+
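Since the prompt table is read from the environment at import time, any override must be exported before agent.py is imported; for example (the value is illustrative):

    import os
    os.environ["TASK_PROMPT"] = "You are an expert web developer."
    import agent  # picks up the override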
+# Instantiate the AI client.
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
 
+def format_prompt_var(message: str, history: List[str]) -> str:
+    """
+    Format the provided message and conversation history into the required prompt format.
+
+    Args:
+        message (str): The current instruction/message.
+        history (List[str]): List of previous conversation entries.
+
+    Returns:
+        str: A formatted prompt string.
+
+    Raises:
+        TypeError: If message is not a string or any history entry is not a string.
+    """
+    if not isinstance(message, str):
+        raise TypeError("The instruction message must be a string.")
+    if not all(isinstance(item, str) for item in history):
+        raise TypeError("All items in history must be strings.")
+
+    history_text = "\n".join(history) if history else "No previous conversation."
+    prompt = f"\n### Instruction:\n{message}\n### History:\n{history_text}"
     return prompt
 
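The resulting prompt is plain text with labelled sections; for example:

    >>> format_prompt_var("add a navbar", ["hi", "hello"])
    '\n### Instruction:\nadd a navbar\n### History:\nhi\nhello'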
+def run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
+    """
+    Run the AI agent with the given instruction and conversation history.
+
+    Args:
+        instruction (str): The user instruction.
+        history (List[str]): The conversation history.
+
+    Returns:
+        Tuple[str, List[str]]: A tuple containing the full AI response and a list of extracted actions.
+
+    Raises:
+        TypeError: If inputs are of invalid type.
+    """
+    if not isinstance(instruction, str):
+        raise TypeError("Instruction must be a string.")
+    if not isinstance(history, list) or not all(isinstance(item, str) for item in history):
+        raise TypeError("History must be a list of strings.")
+
     response = ""
+    iterations = 0
+
+    try:
+        # generate() formats the prompt itself, so the raw instruction is passed
+        # through; it also yields the cumulative output, so the latest chunk is
+        # kept rather than concatenated.
+        for chunk in generate(instruction, history[-config.max_history:], temperature=0.7):
+            response = chunk
+            iterations += 1
+            if "\n\n### Instruction:" in chunk or iterations >= config.max_iterations:
+                break
+    except Exception as e:
+        logging.error("Error in run_agent: %s", e)
+        response += f"\n[Error in run_agent: {e}]"
+
+    # Extract actions from the response.
+    response_actions = []
+    for line in response.strip().split("\n"):
+        if line.startswith("action:"):
+            response_actions.append(line.replace("action: ", "", 1))
+
     return response, response_actions
 
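A minimal synchronous call, assuming the model emits the "action:" line convention described above:

    response, actions = run_agent("add a navbar to the landing page", [])
    for act in actions:
        print("queued action:", act)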
+def generate(prompt: str, history: List[str], temperature: float) -> Generator[str, None, None]:
+    """
+    Generate text from the AI model using the formatted prompt.
+
+    Args:
+        prompt (str): The input prompt.
+        history (List[str]): Recent conversation history.
+        temperature (float): Sampling temperature.
+
+    Yields:
+        str: The cumulative output from the text-generation stream.
+    """
+    seed = random.randint(1, 2**32 - 1) if config.default_seed is None else config.default_seed
     generate_kwargs = {
         "temperature": temperature,
+        "max_new_tokens": config.max_new_tokens,
         "top_p": 0.95,
         "repetition_penalty": 1.0,
         "do_sample": True,
         "seed": seed,
     }
+    formatted_prompt = format_prompt_var(prompt, history)
 
+    try:
+        stream = client.text_generation(
+            formatted_prompt,
+            **generate_kwargs,
+            stream=True,
+            details=True,
+            return_full_text=False
+        )
+    except Exception as e:
+        logging.error("Error during text_generation call: %s", e)
+        yield f"[Error during text_generation call: {e}]"
+        return
+
+    output = ""
+    iterations = 0
     for response in stream:
+        iterations += 1
+        try:
+            output += response.token.text
+        except AttributeError as ae:
+            logging.error("Malformed response token: %s", ae)
+            yield f"[Malformed response token: {ae}]"
+            break
         yield output
+        if iterations >= config.max_iterations:
+            yield "\n[Response truncated due to length limitations]"
+            break
 
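Because each yield is the full text accumulated so far, a streaming consumer simply replaces its display with the latest value; a minimal sketch:

    for partial in generate("hello", [], temperature=0.7):
        print(partial)  # each iteration prints the whole response so far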
+async def async_run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
+    """
+    Asynchronous wrapper to run the agent in a separate thread.
+
+    Args:
+        instruction (str): The instruction for the AI.
+        history (List[str]): The conversation history.
+
+    Returns:
+        Tuple[str, List[str]]: The response and extracted actions.
+    """
+    # asyncio.to_thread requires Python 3.9+.
+    return await asyncio.to_thread(run_agent, instruction, history)
 
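Outside an existing event loop the wrapper can be driven with asyncio.run:

    async def main():
        response, actions = await async_run_agent("hello", [])
        print(response)

    asyncio.run(main())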
+def clear_conversation() -> List[str]:
+    """
+    Clear the conversation history.
+
+    Returns:
+        List[str]: An empty conversation history.
+    """
+    return []
 
+def update_chatbot_styles(history: List[Any]) -> Any:
+    """
+    Update the chatbot display styles based on the number of messages.
+
+    Args:
+        history (List[Any]): The current conversation history.
+
+    Returns:
+        Update object for Gradio Chatbot.
+    """
+    num_messages = sum(1 for item in history if isinstance(item, tuple))
+    # NOTE: relies on the Gradio 3.x Component.update() API; "num_messages" is
+    # not a standard Chatbot property, so this is effectively a no-op hook
+    # unless a custom front end consumes it.
+    return gr.Chatbot.update({"num_messages": num_messages})
+
+def update_max_history(value: int) -> int:
+    """
+    Update max_history in the configuration.
+
+    Args:
+        value (int): New maximum history value.
+
+    Returns:
+        int: The updated max_history.
+    """
+    config.max_history = int(value)
+    return config.max_history
+
+def create_interface() -> gr.Blocks:
+    """
+    Create and return the Gradio interface for the chatbot application.
+
+    Returns:
+        gr.Blocks: The Gradio Blocks object representing the UI.
+    """
+    block = gr.Blocks()
+
+    with block:
+        gr.Markdown("## Expert Web Developer Assistant")
+        with gr.Tab("Conversation"):
+            # Components must be created inside the Blocks context to render.
             chatbot = gr.Chatbot()
             txt = gr.Textbox(show_label=False, placeholder="Type something...")
             btn = gr.Button("Send", variant="primary")
+            clear_btn = gr.Button("Clear")
+
+            async def respond(message: str, chat_history: List[Tuple[str, str]]):
+                # Flatten the (user, bot) pairs into the flat string history
+                # that run_agent expects.
+                flat_history = [m for pair in (chat_history or []) for m in pair]
+                response, _actions = await async_run_agent(message, flat_history)
+                chat_history = (chat_history or []) + [(message, response)]
+                return "", chat_history
+
+            # Send the message on Enter or via the Send button.
+            txt.submit(respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
+            btn.click(respond, inputs=[txt, chatbot], outputs=[txt, chatbot])
+            # gr.Textbox has no clear event, so clearing gets its own button;
+            # it empties the history and then refreshes the chatbot styles.
+            clear_btn.click(fn=clear_conversation, outputs=chatbot).then(
+                update_chatbot_styles, chatbot, chatbot
+            )
 
+        with gr.Tab("Settings"):
+            max_history_slider = gr.Slider(
+                minimum=1, maximum=100, step=1,
+                label="Max history",
+                value=config.max_history
+            )
+            max_history_slider.change(
+                update_max_history, max_history_slider, max_history_slider
+            )
 
     return block
 
+if __name__ == "__main__":
+    interface = create_interface()
+    interface.launch()
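If streaming generator outputs are wired into the UI later, Gradio generally requires the queue to be enabled, e.g.:

    interface.queue().launch()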