File size: 6,480 Bytes
e1392d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# test_agent.py

import logging
from twisted.internet import reactor, defer, threads
from agent import AutonomousWebAgent
from ToTSearch import ToTSearch

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Initialize the logger
logger = logging.getLogger(__name__)

# Suppress detailed logs for some libraries (like Scrapy or Transformers)
logging.getLogger('scrapy').setLevel(logging.ERROR)
logging.getLogger('transformers').setLevel(logging.ERROR)
logging.getLogger('twisted').setLevel(logging.ERROR)

import warnings
warnings.filterwarnings("ignore", category=FutureWarning)


class TestAgent:
    def __init__(self):
        # Initialize the AutonomousWebAgent
        state_size = 7  # word_count, link_count, header_count, semantic_similarity, image_count, script_count, css_count
        action_size = 3  # 0: Click Link, 1: Summarize, 2: RAG Generate
        num_options = 3  # 0: Search, 1: Summarize, 2: RAG Generate

        self.agent = AutonomousWebAgent(
            state_size=state_size,
            action_size=action_size,
            num_options=num_options,
            hidden_size=64,
            learning_rate=0.001,
            gamma=0.99,
            epsilon=1.0,
            epsilon_decay=0.995,
            epsilon_min=0.01,
            knowledge_base_path='knowledge_base.json'
        )

        # Initialize ToTSearch with the agent
        self.tot_search = ToTSearch(self.agent)

        # Few-shot examples for Tree of Thoughts
        self.few_shot_examples = [
            {
                "query": "What are the effects of climate change on biodiversity?",
                "thoughts": [
                    "Loss of habitats due to rising sea levels and changing temperatures",
                    "Disruption of ecosystems and food chains",
                    "Increased extinction rates for vulnerable species"
                ],
                "answer": "Climate change significantly impacts biodiversity through habitat loss, ecosystem disruption, and increased extinction rates. Rising temperatures and sea levels alter habitats, forcing species to adapt or migrate. This disrupts established ecosystems and food chains. Species unable to adapt quickly face a higher risk of extinction, particularly those with specialized habitats or limited ranges."
            },
            {
                "query": "How can we promote sustainable energy adoption?",
                "thoughts": [
                    "Government policies and incentives",
                    "Public awareness and education campaigns",
                    "Technological advancements and cost reduction"
                ],
                "answer": "Promoting sustainable energy adoption requires a multi-faceted approach. Government policies and incentives can encourage both businesses and individuals to switch to renewable sources. Public awareness and education campaigns help people understand the importance and benefits of sustainable energy. Continued technological advancements and cost reductions make sustainable energy more accessible and economically viable for widespread adoption."
            }
        ]

    @defer.inlineCallbacks
    def process_query(self, query, is_few_shot=False):
        logger.info(f"Processing query: {query}")
        try:
            if is_few_shot:
                few_shot_prompt = self.create_few_shot_prompt(query)
                enhanced_query = f"{few_shot_prompt}\n\nQuery: {query}"
                logger.debug(f"Enhanced query for few-shot learning: {enhanced_query[:100]}...")
                final_answer = yield self.tot_search.search(enhanced_query)
            else:
                final_answer = yield self.tot_search.search(query)
            
            logger.info(f"Final answer for '{query}':")
            logger.info(final_answer)
            
            yield self.agent.add_document_to_kb(title=f"ToT Search Result: {query}", content=final_answer)
            
            yield self.agent.replay_worker(batch_size=32)
            yield self.agent.replay_manager(batch_size=32)
            
            return final_answer
        except Exception as e:
            logger.error(f"Error processing query '{query}': {str(e)}", exc_info=True)
            return f"An error occurred: {str(e)}"
        
    def create_few_shot_prompt(self, query):
        prompt = "Here are some examples of how to approach queries using a Tree of Thoughts:\n\n"
        for example in self.few_shot_examples:
            prompt += f"Query: {example['query']}\n"
            prompt += "Thoughts:\n"
            for thought in example['thoughts']:
                prompt += f"- {thought}\n"
            prompt += f"Answer: {example['answer']}\n\n"
        prompt += f"Now, let's approach the following query in a similar manner:\n\nQuery: {query}\n"
        return prompt

    def save_models(self):
        self.agent.save_worker_model("worker_model_final.pth")
        self.agent.save_manager_model("manager_model_final.pth")
        logger.info("Agent models saved.")


def get_user_input():
    return input("Enter your query (or 'quit' to exit): ")


@defer.inlineCallbacks
def run_test_session():
    test_agent = TestAgent()
    
    logger.info("Starting few-shot learning phase...")
    for example in test_agent.few_shot_examples:
        logger.info(f"Processing few-shot example: {example['query']}")
        try:
            yield test_agent.process_query(example['query'], is_few_shot=True)
        except Exception as e:
            logger.error(f"Error in few-shot learning: {str(e)}", exc_info=True)
    
    logger.info("Few-shot learning phase completed. Starting interactive session.")
    
    while True:
        query = yield threads.deferToThread(get_user_input)
        
        if query.lower() == 'quit':
            break
        
        try:
            answer = yield test_agent.process_query(query)
            print("\nAgent's response:")
            print(answer)
            print("\n" + "-"*50 + "\n")
        except Exception as e:
            logger.error(f"Error in interactive session: {str(e)}", exc_info=True)
    
    test_agent.save_models()
    reactor.stop()


if __name__ == "__main__":
    reactor.callWhenRunning(run_test_session)
    reactor.run()