File size: 2,961 Bytes
1d91d99
acce42e
6aae614
f31224c
9b5b26a
 
1d91d99
9b5b26a
f31224c
 
 
9b5b26a
f31224c
 
9b5b26a
 
f31224c
 
 
 
 
 
9b5b26a
f31224c
8c01ffb
6aae614
ae7a494
e121372
1d91d99
 
 
 
13d500a
1d91d99
1e2135d
 
9b5b26a
8c01ffb
8fe992b
1d91d99
8c01ffb
 
 
 
 
 
1e2135d
8fe992b
 
1d91d99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b5b26a
1d91d99
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from smolagents import CodeAgent, HfApiModel, tool
import yaml
from tools.final_answer import FinalAnswerTool
import wikipedia
from Gradio_UI import GradioUI

# Wikipedia search tool
@tool
def wikipedia_search(query: str, sentences: int = 2) -> str:
    """Look up a term on Wikipedia and return a short plain-text summary.

    Args:
        query: The search term for Wikipedia.
        sentences: The number of sentences to return from the summary.
    """
    try:
        return wikipedia.summary(query, sentences=sentences)
    except wikipedia.exceptions.DisambiguationError as e:
        # Ambiguous term: surface up to five candidate page titles instead.
        options = ', '.join(e.options[:5])
        return f"Multiple results found: {options}..."
    except wikipedia.exceptions.PageError:
        return "No Wikipedia page found for that query."
    except Exception as e:
        # Tool boundary: report any other failure as text rather than raising.
        return f"An error occurred: {str(e)}"

# Tool the agent uses to emit its final answer.
final_answer = FinalAnswerTool()

# Hosted inference model backing the agent.
model = HfApiModel(
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    max_tokens=2096,
    temperature=0.5,
    custom_role_conversions=None,
)

# Prompt templates are loaded from the YAML file shipped next to this script.
with open("prompts.yaml", 'r') as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[final_answer, wikipedia_search],
    prompt_templates=prompt_templates,
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
)

# Custom GradioUI that resets agent context after 4 user messages/responses.
class CustomGradioUI(GradioUI):
    """GradioUI variant that clears the agent's context after a fixed
    number of user interactions, so long conversations do not accumulate
    unbounded history.
    """

    def __init__(self, agent, max_messages=4):
        """Wrap *agent* and reset its context every *max_messages* turns.

        Args:
            agent: The agent whose conversation context is periodically reset.
            max_messages: Number of processed user messages after which the
                agent's context is cleared (default 4).
        """
        super().__init__(agent)
        self.max_messages = max_messages
        self.message_count = 0  # user messages handled since the last reset

    def process_user_input(self, user_input):
        """Run the agent on *user_input*, then reset its context once
        ``max_messages`` interactions have been processed.

        Returns:
            The agent's response for this input.
        """
        response = self.agent.run(user_input)
        self.message_count += 1

        if self.message_count >= self.max_messages:
            self._reset_agent_context()
            self.message_count = 0
        return response

    def _reset_agent_context(self):
        """Best-effort reset of the wrapped agent's conversation state."""
        if hasattr(self.agent, 'reset'):
            # Prefer the agent's own reset method when it exists.
            self.agent.reset()
        elif hasattr(self.agent, 'conversation_history'):
            # Fall back to clearing the raw history container.
            self.agent.conversation_history.clear()

    def launch(self, *args, **kwargs):
        """Launch the UI, forwarding all arguments to ``GradioUI.launch``.

        NOTE(review): the reset logic only takes effect if the base GradioUI
        routes user messages through ``process_user_input`` — confirm against
        the GradioUI implementation.
        """
        # Forward *args/**kwargs so callers can still pass Gradio launch
        # options (e.g. share=True); the previous override swallowed them.
        super().launch(*args, **kwargs)

# Launch the custom UI only when executed as a script, so importing this
# module (e.g. for testing) does not start a Gradio server as a side effect.
if __name__ == "__main__":
    CustomGradioUI(agent).launch()