whackthejacker commited on
Commit
4b1134a
·
verified ·
1 Parent(s): c2c5916

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +225 -72
agent.py CHANGED
@@ -1,108 +1,261 @@
 
1
  import os
2
- import subprocess
3
  import random
4
- from huggingface_hub import InferenceClient
5
  import gradio as gr
6
- from safe_search import safe_search
7
- from i_search import google
8
- from i_search import i_search as i_s
9
- from datetime import datetime
10
- from utils import parse_action, parse_file_content, read_python_module_structure
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- now = datetime.now()
13
- date_time_str = now.strftime("%Y-%m-%d %H:%M:%S")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
16
 
17
- VERBOSE = True
18
- MAX_HISTORY = 100
19
-
20
- # Prompts
21
- ACTION_PROMPT = "action prompt"
22
- ADD_PROMPT = "add prompt"
23
- COMPRESS_HISTORY_PROMPT = "compress history prompt"
24
- LOG_PROMPT = "log prompt"
25
- LOG_RESPONSE = "log response"
26
- MODIFY_PROMPT = "modify prompt"
27
- PREFIX = "prefix"
28
- SEARCH_QUERY = "search query"
29
- READ_PROMPT = "read prompt"
30
- TASK_PROMPT = "task prompt"
31
- UNDERSTAND_TEST_RESULTS_PROMPT = "understand test results prompt"
32
-
33
- def format_prompt_var(message, history):
34
- prompt = "\n### Instruction:\n{}\n### History:\n{}".format(message, '\n'.join(history))
 
 
 
35
  return prompt
36
 
37
- def run_agent(instruction, history):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  prompt = format_prompt_var(instruction, history)
39
  response = ""
40
- for chunk in generate(prompt, history[-MAX_HISTORY:], temperature=0.7):
41
- response += chunk
42
- if "\n\n### Instruction:" in chunk:
43
- break
44
 
45
- response_actions = []
46
- for line in response.strip().split('\n'):
47
- if line.startswith('action:'):
48
- response_actions.append((line.replace('action: ', '')))
 
 
 
 
 
49
 
 
 
 
 
 
 
50
  return response, response_actions
51
 
52
- def generate(prompt, history, temperature):
53
- seed = random.randint(1, 1111111111111111)
 
 
 
 
 
 
 
 
 
 
 
54
  generate_kwargs = {
55
  "temperature": temperature,
56
- "max_new_tokens": 256,
57
  "top_p": 0.95,
58
  "repetition_penalty": 1.0,
59
  "do_sample": True,
60
  "seed": seed,
61
  }
62
- formatted_prompt = format_prompt_var(f"{prompt}", history)
63
- stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
64
- output = ""
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  for response in stream:
67
- output += response.token.text
 
 
 
 
 
 
68
  yield output
 
 
 
69
 
70
- def create_interface():
71
- global MAX_HISTORY
 
 
 
 
 
 
 
 
 
 
72
 
73
- block = gr.Blocks()
 
 
 
 
 
 
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  chatbot = gr.Chatbot()
76
- with block.title("Expert Web Developer Assistant"):
77
- with block.tab("Conversation"):
 
 
78
  txt = gr.Textbox(show_label=False, placeholder="Type something...")
79
  btn = gr.Button("Send", variant="primary")
80
 
81
- txt.submit(run_agent, inputs=[txt, chatbot], outputs=[chatbot, None])
82
- txt.clear(None, [txt, chatbot]).then(_clear_history, chatbot, _update_chatbot_styles)
83
- btn.click(_clear_history, chatbot, _update_chatbot_styles)
 
 
 
 
 
 
 
 
 
 
84
 
85
- with block.tab("Settings"):
86
- MAX_HISTORY_slider = gr.Slider(minimum=1, maximum=100, step=1, label="Max history", value=MAX_HISTORY)
87
- MAX_HISTORY_slider.change(lambda x: setattr(block, "MAX_HISTORY", int(x)), MAX_HISTORY_slider)
 
 
 
 
 
 
88
 
89
  return block
90
 
91
- def _update_chatbot_styles(history):
92
- num_messages = sum([1 for item in history if isinstance(item, tuple)])
93
- gr.Chatbot.update({"num_messages": num_messages})
94
-
95
- def _clear_history(history):
96
- return [], []
97
-
98
- # Exportable functions and variables
99
- __all__ = [
100
- "run_agent",
101
- "create_interface",
102
- "format_prompt_var",
103
- "generate",
104
- # "MAX_HISTORY",
105
- "client",
106
- "VERBOSE",
107
- "date_time_str",
108
- ]
 
1
+
2
  import os
 
3
  import random
4
+ import logging
5
  import gradio as gr
6
+ import asyncio
7
+ from typing import List, Tuple, Generator, Any
8
+ from inference_client import InferenceClient # Adjust the import as needed
9
+
10
+ # Set up logging to capture errors and warnings.
11
+ logging.basicConfig(
12
+ level=logging.INFO,
13
+ filename='chatbot.log',
14
+ format='%(asctime)s - %(levelname)s - %(message)s'
15
+ )
16
+
17
+ # Encapsulated configuration to avoid global variable pitfalls.
18
+ class ChatbotConfig:
19
+ def __init__(
20
+ self,
21
+ max_history: int = 100,
22
+ verbose: bool = True,
23
+ max_iterations: int = 1000,
24
+ max_new_tokens: int = 256,
25
+ default_seed: int = None
26
+ ):
27
+ self.max_history = max_history
28
+ self.verbose = verbose
29
+ self.max_iterations = max_iterations
30
+ self.max_new_tokens = max_new_tokens
31
+ self.default_seed = default_seed or random.randint(1, 2**32 - 1)
32
 
33
+ # Global configuration instance.
34
+ config = ChatbotConfig()
35
 
36
+ # Externalize prompts into a dictionary, optionally overridden by environment variables.
37
+ PROMPTS = {
38
+ "ACTION_PROMPT": os.environ.get("ACTION_PROMPT", "action prompt"),
39
+ "ADD_PROMPT": os.environ.get("ADD_PROMPT", "add prompt"),
40
+ "COMPRESS_HISTORY_PROMPT": os.environ.get("COMPRESS_HISTORY_PROMPT", "compress history prompt"),
41
+ "LOG_PROMPT": os.environ.get("LOG_PROMPT", "log prompt"),
42
+ "LOG_RESPONSE": os.environ.get("LOG_RESPONSE", "log response"),
43
+ "MODIFY_PROMPT": os.environ.get("MODIFY_PROMPT", "modify prompt"),
44
+ "PREFIX": os.environ.get("PREFIX", "prefix"),
45
+ "SEARCH_QUERY": os.environ.get("SEARCH_QUERY", "search query"),
46
+ "READ_PROMPT": os.environ.get("READ_PROMPT", "read prompt"),
47
+ "TASK_PROMPT": os.environ.get("TASK_PROMPT", "task prompt"),
48
+ "UNDERSTAND_TEST_RESULTS_PROMPT": os.environ.get("UNDERSTAND_TEST_RESULTS_PROMPT", "understand test results prompt")
49
+ }
50
+
51
+ # Instantiate the AI client.
52
  client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
53
 
54
+ def format_prompt_var(message: str, history: List[str]) -> str:
55
+ """
56
+ Format the provided message and conversation history into the required prompt format.
57
+
58
+ Args:
59
+ message (str): The current instruction/message.
60
+ history (List[str]): List of previous conversation entries.
61
+
62
+ Returns:
63
+ str: A formatted prompt string.
64
+
65
+ Raises:
66
+ TypeError: If message is not a string or any history entry is not a string.
67
+ """
68
+ if not isinstance(message, str):
69
+ raise TypeError("The instruction message must be a string.")
70
+ if not all(isinstance(item, str) for item in history):
71
+ raise TypeError("All items in history must be strings.")
72
+
73
+ history_text = "\n".join(history) if history else "No previous conversation."
74
+ prompt = f"\n### Instruction:\n{message}\n### History:\n{history_text}"
75
  return prompt
76
 
77
+ def run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
78
+ """
79
+ Run the AI agent with the given instruction and conversation history.
80
+
81
+ Args:
82
+ instruction (str): The user instruction.
83
+ history (List[str]): The conversation history.
84
+
85
+ Returns:
86
+ Tuple[str, List[str]]: A tuple containing the full AI response and a list of extracted actions.
87
+
88
+ Raises:
89
+ TypeError: If inputs are of invalid type.
90
+ """
91
+ if not isinstance(instruction, str):
92
+ raise TypeError("Instruction must be a string.")
93
+ if not isinstance(history, list) or not all(isinstance(item, str) for item in history):
94
+ raise TypeError("History must be a list of strings.")
95
+
96
  prompt = format_prompt_var(instruction, history)
97
  response = ""
98
+ iterations = 0
 
 
 
99
 
100
+ try:
101
+ for chunk in generate(prompt, history[-config.max_history:], temperature=0.7):
102
+ response += chunk
103
+ iterations += 1
104
+ if "\n\n### Instruction:" in chunk or iterations >= config.max_iterations:
105
+ break
106
+ except Exception as e:
107
+ logging.error("Error in run_agent: %s", e)
108
+ response += f"\n[Error in run_agent: {e}]"
109
 
110
+ # Extract actions from the response.
111
+ response_actions = []
112
+ for line in response.strip().split("\n"):
113
+ if line.startswith("action:"):
114
+ response_actions.append(line.replace("action: ", ""))
115
+
116
  return response, response_actions
117
 
118
+ def generate(prompt: str, history: List[str], temperature: float) -> Generator[str, None, None]:
119
+ """
120
+ Generate text from the AI model using the formatted prompt.
121
+
122
+ Args:
123
+ prompt (str): The input prompt.
124
+ history (List[str]): Recent conversation history.
125
+ temperature (float): Sampling temperature.
126
+
127
+ Yields:
128
+ str: Incremental output from the text-generation stream.
129
+ """
130
+ seed = random.randint(1, 2**32 - 1) if config.default_seed is None else config.default_seed
131
  generate_kwargs = {
132
  "temperature": temperature,
133
+ "max_new_tokens": config.max_new_tokens,
134
  "top_p": 0.95,
135
  "repetition_penalty": 1.0,
136
  "do_sample": True,
137
  "seed": seed,
138
  }
139
+ formatted_prompt = format_prompt_var(prompt, history)
 
 
140
 
141
+ try:
142
+ stream = client.text_generation(
143
+ formatted_prompt,
144
+ **generate_kwargs,
145
+ stream=True,
146
+ details=True,
147
+ return_full_text=False
148
+ )
149
+ except Exception as e:
150
+ logging.error("Error during text_generation call: %s", e)
151
+ yield f"[Error during text_generation call: {e}]"
152
+ return
153
+
154
+ output = ""
155
+ iterations = 0
156
  for response in stream:
157
+ iterations += 1
158
+ try:
159
+ output += response.token.text
160
+ except AttributeError as ae:
161
+ logging.error("Malformed response token: %s", ae)
162
+ yield f"[Malformed response token: {ae}]"
163
+ break
164
  yield output
165
+ if iterations >= config.max_iterations:
166
+ yield "\n[Response truncated due to length limitations]"
167
+ break
168
 
169
+ async def async_run_agent(instruction: str, history: List[str]) -> Tuple[str, List[str]]:
170
+ """
171
+ Asynchronous wrapper to run the agent in a separate thread.
172
+
173
+ Args:
174
+ instruction (str): The instruction for the AI.
175
+ history (List[str]): The conversation history.
176
+
177
+ Returns:
178
+ Tuple[str, List[str]]: The response and extracted actions.
179
+ """
180
+ return await asyncio.to_thread(run_agent, instruction, history)
181
 
182
+ def clear_conversation() -> List[str]:
183
+ """
184
+ Clear the conversation history.
185
+
186
+ Returns:
187
+ List[str]: An empty conversation history.
188
+ """
189
+ return []
190
 
191
+ def update_chatbot_styles(history: List[Any]) -> Any:
192
+ """
193
+ Update the chatbot display styles based on the number of messages.
194
+
195
+ Args:
196
+ history (List[Any]): The current conversation history.
197
+
198
+ Returns:
199
+ Update object for Gradio Chatbot.
200
+ """
201
+ num_messages = sum(1 for item in history if isinstance(item, tuple))
202
+ return gr.Chatbot.update({"num_messages": num_messages})
203
+
204
+ def update_max_history(value: int) -> int:
205
+ """
206
+ Update the max_history in configuration.
207
+
208
+ Args:
209
+ value (int): New maximum history value.
210
+
211
+ Returns:
212
+ int: The updated max_history.
213
+ """
214
+ config.max_history = int(value)
215
+ return config.max_history
216
+
217
+ def create_interface() -> gr.Blocks:
218
+ """
219
+ Create and return the Gradio interface for the chatbot application.
220
+
221
+ Returns:
222
+ gr.Blocks: The Gradio Blocks object representing the UI.
223
+ """
224
+ block = gr.Blocks()
225
  chatbot = gr.Chatbot()
226
+
227
+ with block:
228
+ gr.Markdown("## Expert Web Developer Assistant")
229
+ with gr.Tab("Conversation"):
230
  txt = gr.Textbox(show_label=False, placeholder="Type something...")
231
  btn = gr.Button("Send", variant="primary")
232
 
233
+ # When text is submitted, run the agent asynchronously.
234
+ txt.submit(
235
+ async_run_agent,
236
+ inputs=[txt, chatbot],
237
+ outputs=[chatbot, None]
238
+ )
239
+ # Clear conversation history and update chatbot UI.
240
+ txt.clear(fn=clear_conversation, outputs=chatbot).then(
241
+ update_chatbot_styles, chatbot, chatbot
242
+ )
243
+ btn.click(fn=clear_conversation, outputs=chatbot).then(
244
+ update_chatbot_styles, chatbot, chatbot
245
+ )
246
 
247
+ with gr.Tab("Settings"):
248
+ max_history_slider = gr.Slider(
249
+ minimum=1, maximum=100, step=1,
250
+ label="Max history",
251
+ value=config.max_history
252
+ )
253
+ max_history_slider.change(
254
+ update_max_history, max_history_slider, max_history_slider
255
+ )
256
 
257
  return block
258
 
259
+ if __name__ == "__main__":
260
+ interface = create_interface()
261
+ interface.launch()