File size: 2,961 Bytes
1d91d99 acce42e 6aae614 f31224c 9b5b26a 1d91d99 9b5b26a f31224c 9b5b26a f31224c 9b5b26a f31224c 9b5b26a f31224c 8c01ffb 6aae614 ae7a494 e121372 1d91d99 13d500a 1d91d99 1e2135d 9b5b26a 8c01ffb 8fe992b 1d91d99 8c01ffb 1e2135d 8fe992b 1d91d99 9b5b26a 1d91d99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from smolagents import CodeAgent, HfApiModel, tool
import yaml
from tools.final_answer import FinalAnswerTool
import wikipedia
from Gradio_UI import GradioUI
# Wikipedia search tool exposed to the agent (registered via smolagents' @tool).
@tool
def wikipedia_search(query: str, sentences: int = 2) -> str:
    """Search Wikipedia and return a short summary.

    Args:
        query: The search term for Wikipedia.
        sentences: The number of sentences to return from the summary.

    Returns:
        The page summary, or a human-readable message describing why no summary could be produced.
    """
    # Reject empty / whitespace-only queries up front instead of sending them to
    # the API and surfacing whatever low-level error comes back.
    if not query or not query.strip():
        return "Please provide a non-empty search query."

    try:
        return wikipedia.summary(query.strip(), sentences=sentences)
    except wikipedia.exceptions.DisambiguationError as e:
        # Show at most five candidate titles; only add an ellipsis when some
        # options were actually left out.
        options = e.options
        shown = ", ".join(options[:5])
        suffix = "..." if len(options) > 5 else ""
        return f"Multiple results found: {shown}{suffix}"
    except wikipedia.exceptions.PageError:
        return "No Wikipedia page found for that query."
    except Exception as e:
        # Deliberate best-effort: a tool should hand the agent an error string it
        # can reason about rather than crash the agent loop.
        return f"An error occurred: {str(e)}"
# Final-answer tool (presumably how the agent hands back its result — see
# tools/final_answer.py).
final_answer = FinalAnswerTool()

# LLM backend driving the agent.
# NOTE(review): max_tokens=2096 may be a typo for 2048 — confirm it is intended.
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    custom_role_conversions=None,
)

# Prompt templates for the agent, read from prompts.yaml relative to the current
# working directory. safe_load avoids constructing arbitrary Python objects, and
# the explicit encoding keeps decoding independent of the platform locale.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

# Code-writing agent equipped with the final-answer and Wikipedia tools,
# limited to six steps per run.
agent = CodeAgent(
    model=model,
    tools=[final_answer, wikipedia_search],
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates,
)
# Gradio UI wrapper that clears the agent's conversational context after a fixed
# number of completed user interactions (4 by default).
class CustomGradioUI(GradioUI):
    """GradioUI that resets the agent's context every ``max_messages`` interactions.

    Each completed interaction (one user message plus the agent's response)
    increments ``message_count``. Once the count reaches ``max_messages`` the
    agent's context is cleared and the counter restarts at zero.
    """

    def __init__(self, agent, max_messages=4):
        """Wrap *agent* in the base UI and initialise the interaction counter.

        Args:
            agent: The agent to serve; passed unchanged to ``GradioUI``.
            max_messages: Completed interactions after which the agent's
                context is reset. Values <= 1 reset after every interaction.
        """
        super().__init__(agent)
        self.max_messages = max_messages
        self.message_count = 0

    def process_user_input(self, user_input):
        """Run the agent on one user message and return its response.

        The agent is run with ``reset=False`` so earlier turns stay in context
        until the periodic reset clears them; with the default ``reset=True``
        every run would start from an empty memory and the counter would be
        meaningless.

        NOTE(review): nothing in this file calls this method; it only takes
        effect if the base GradioUI routes messages through it — confirm.
        """
        response = self.agent.run(user_input, reset=False)
        self._record_interaction()
        return response

    def interact_with_agent(self, *args, **kwargs):
        """Delegate one chat turn to the base UI, then count it.

        NOTE(review): assumes the base GradioUI streams each chat turn through
        a generator method named ``interact_with_agent`` (as in the smolagents
        template UI) — confirm against Gradio_UI.py. If the base class has no
        such method, this override is simply never invoked.
        """
        yield from super().interact_with_agent(*args, **kwargs)
        # Only reached once the base generator has finished streaming the turn.
        self._record_interaction()

    def launch(self, *args, **kwargs):
        """Start the UI, forwarding any launch options to the base class."""
        return super().launch(*args, **kwargs)

    def _record_interaction(self):
        """Count one completed interaction and reset the agent when the limit is hit."""
        self.message_count += 1
        if self.message_count >= self.max_messages:
            self._reset_agent_context()
            self.message_count = 0

    def _reset_agent_context(self):
        """Clear the agent's context using the first reset hook it exposes."""
        agent = self.agent

        # 1. An agent-level reset method, if the agent provides one.
        reset = getattr(agent, "reset", None)
        if callable(reset):
            reset()
            return

        # 2. smolagents agents keep their step history in ``agent.memory``.
        memory_reset = getattr(getattr(agent, "memory", None), "reset", None)
        if callable(memory_reset):
            memory_reset()
            return

        # 3. A plain conversation-history container.
        history = getattr(agent, "conversation_history", None)
        if history is not None:
            history.clear()
            return

        # Nothing to reset: say so instead of silently keeping stale context.
        import warnings

        warnings.warn(
            "CustomGradioUI: agent exposes no reset hook; context was not cleared.",
            RuntimeWarning,
            stacklevel=2,
        )
# Build the chat UI around the configured agent and start serving it.
ui = CustomGradioUI(agent)
ui.launch()
|