# SOP_Generation-multi / Environment / base_environment.py
# (source-listing header: author callanwu, commit 200916c "init")
from utils import get_relevant_history, get_embedding
import torch
from LLM.base_LLM import *
from Memory import Memory
from Prompt import *
import json
class Environment:
    """
    The place where agents act; responsible for storing memory shared between them.

    ``shared_memory`` holds:
      * ``"long_term_memory"``: list of every memory passed to :meth:`update_memory`,
        in arrival order;
      * ``"short_term_memory"``: the latest result of :meth:`summary` (``None`` until
        the first summary is produced);
      * ``"chat_embeddings"``: created on the first :meth:`update_memory` call; the
        embeddings of those memories concatenated along dim 0.

    NOTE(review): several methods build prompts with ``eval`` on templates imported
    from ``Prompt``. Those templates presumably reference the calling method's local
    variables by name (verify in ``Prompt``), so locals in ``__init__``, ``summary``
    and ``_observe`` are intentionally left with their original names.
    """

    @staticmethod
    def _max_chat_history():
        """Read the ``MAX_CHAT_HISTORY`` environment variable as an int.

        Raises:
            KeyError: if the variable is not set.
            ValueError: if the value is not an integer literal.
        """
        # int() instead of eval(): the value is a plain count, and eval would run any
        # expression placed in the environment.
        return int(os.environ["MAX_CHAT_HISTORY"])

    @staticmethod
    def _tail(seq, n):
        """Return the last ``n`` items of ``seq`` (list or tensor); empty when ``n <= 0``.

        Unlike ``seq[-n:]``, this does not return the whole sequence when ``n == 0``.
        """
        start = len(seq) - n if n > 0 else len(seq)
        return seq[max(start, 0):]

    def __init__(self, config) -> None:
        """Build the environment from a parsed config dict.

        Args:
            config: dict with a ``"states"`` mapping (state name -> state dict) and an
                optional ``"environment_type"`` (default ``"cooperative"``). For every
                state except ``"end_state"``, the optional keys
                ``"summary_system_prompt"``, ``"summary_last_prompt"`` and
                ``"environment_prompt"`` are read, and the whole state dict is passed
                to ``init_LLM`` as keyword arguments.
        """
        self.shared_memory = {"long_term_memory": [], "short_term_memory": None}
        # Must be set before update_memory() is called: it is indexed by memory.send_name.
        self.agents = None
        self.summary_system_prompt = {}
        self.summary_last_prompt = {}
        self.environment_prompt = {}
        self.environment_type = config.get("environment_type", "cooperative")
        # Start index into long_term_memory used by _observe in competitive mode;
        # presumably advanced by the caller when the state changes.
        self.current_chat_history_idx = 0
        self.LLMs = {}
        # Initialize the summary prompts and the LLM for each state.
        for state_name, state_dict in config["states"].items():
            if state_name == "end_state":
                continue
            # The fallback templates are eval'd in this frame and may refer to
            # `self`, `state_name`, `state_dict` or `config`.
            self.summary_system_prompt[state_name] = (
                state_dict["summary_system_prompt"]
                if "summary_system_prompt" in state_dict
                else eval(Default_environment_summary_system_prompt)
            )
            self.summary_last_prompt[state_name] = (
                state_dict["summary_last_prompt"]
                if "summary_last_prompt" in state_dict
                else eval(Default_environment_summary_last_prompt)
            )
            self.environment_prompt[state_name] = (
                state_dict["environment_prompt"]
                if "environment_prompt" in state_dict
                else " "
            )
            self.LLMs[state_name] = init_LLM("logs" + os.sep + f"{state_name}", **state_dict)
        self.roles_to_names = None
        self.names_to_roles = None

    @classmethod
    def from_config(cls, config_path):
        """Load a JSON config file and build an environment from it.

        The file is decoded as UTF-8 so configs containing non-ASCII prompt text load
        regardless of the platform's default encoding.
        """
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        return cls(config)

    def summary(self, current_state):
        """Ask the current state's LLM to summarize the shared conversation so far.

        Args:
            current_state: state object; only its ``name`` is used, to select the
                prompts and the LLM configured for that state.

        Returns:
            Whatever the state's LLM ``get_response`` returns for the assembled prompt.
        """
        MAX_CHAT_HISTORY = self._max_chat_history()
        current_state_name = current_state.name
        query = self.shared_memory["long_term_memory"][-1].content
        if len(self.shared_memory["long_term_memory"]) > 1:
            # Search earlier messages (the newest one is excluded) using their stored
            # embeddings; presumably returns the entries most related to `query`.
            relevant_history = get_relevant_history(
                query,
                self.shared_memory["long_term_memory"][:-1],
                self.shared_memory["chat_embeddings"][:-1],
            )
            relevant_history = Memory.get_chat_history(relevant_history)
        else:
            relevant_history = ""
        # The most recent MAX_CHAT_HISTORY - 1 messages. _tail avoids the `xs[-0:]`
        # pitfall that returned the entire history when MAX_CHAT_HISTORY == 1.
        chat_history = Memory.get_chat_history(
            self._tail(self.shared_memory["long_term_memory"], MAX_CHAT_HISTORY - 1)
        )
        summary = self.shared_memory["short_term_memory"]
        # system prompt = environment prompt + current memory + system prompt
        # current_memory = summary + chat history + relevant history
        current_memory = eval(Environment_summary_memory)
        environment_prompt = self.environment_prompt[current_state_name]
        summary_system_prompt = self.summary_system_prompt[current_state_name]
        environment_summary_system_prompt = eval(Environment_summary_system_prompt)
        response = self.LLMs[current_state_name].get_response(
            None, environment_summary_system_prompt, stream=False
        )
        return response

    def update_memory(self, memory, current_state):
        """Record a new message in shared memory and forward it to its sender.

        Appends ``memory`` to the long-term memory and its embedding to
        ``chat_embeddings``. Every ``MAX_CHAT_HISTORY`` messages it regenerates the
        short-term summary. Finally it calls ``update_memory`` on the agent named
        ``memory.send_name``.
        """
        MAX_CHAT_HISTORY = self._max_chat_history()
        self.shared_memory["long_term_memory"].append(memory)
        current_embedding = get_embedding(memory.content)
        if "chat_embeddings" in self.shared_memory:
            self.shared_memory["chat_embeddings"] = torch.cat(
                [self.shared_memory["chat_embeddings"], current_embedding], dim=0
            )
        else:
            self.shared_memory["chat_embeddings"] = current_embedding
        if len(self.shared_memory["long_term_memory"]) % MAX_CHAT_HISTORY == 0:
            self.shared_memory["short_term_memory"] = self.summary(current_state)
        self.agents[memory.send_name].update_memory(memory)

    def _get_agent_last_conversation_idx(self, agent, current_long_term_memory):
        """Return the index of the last entry whose ``send_name`` is ``agent.name``, or -1."""
        for idx in range(len(current_long_term_memory) - 1, -1, -1):
            if current_long_term_memory[idx].send_name == agent.name:
                return idx
        return -1

    def _get_agent_new_memory(self, agent, current_long_term_memory):
        """Format the messages that came after ``agent``'s own last message.

        If the agent never spoke, the whole list is used. If the result has more than
        ``2 * MAX_CHAT_HISTORY`` entries, only the last ``2 * MAX_CHAT_HISTORY - 1`` are kept.
        """
        last_conversation_idx = self._get_agent_last_conversation_idx(
            agent, current_long_term_memory
        )
        # idx -1 (never spoke) -> whole list; last index -> empty list.
        new_conversation = current_long_term_memory[last_conversation_idx + 1:]
        MAX_CHAT_HISTORY = self._max_chat_history()
        if len(new_conversation) > 2 * MAX_CHAT_HISTORY:
            # NOTE(review): exactly 2*MAX entries are kept whole, but a longer list is
            # cut to 2*MAX - 1. This is preserved from the original; confirm the intent.
            new_conversation = self._tail(new_conversation, 2 * MAX_CHAT_HISTORY - 1)
        # get chat history from new conversation
        return Memory.get_chat_history(new_conversation)

    def _observe(self, agent):
        """Build the user message describing the environment to ``agent``.

        Side effect: sets ``agent.relevant_memory``.

        Returns:
            ``{"role": "user", "content": current_memory}``, where ``current_memory`` is
            the ``Agent_observe_memory`` template evaluated in this frame.
        """
        MAX_CHAT_HISTORY = self._max_chat_history()
        current_state = agent.current_state
        current_role = agent.state_roles[current_state.name]
        # Not used directly below; kept because the eval'd templates may reference it.
        current_component_dict = current_state.components[current_role]
        # cooperative: history is shared between states. competitive: only messages from
        # self.current_chat_history_idx onward are visible. Both the historical spelling
        # "competive" and "competitive" are accepted.
        current_chat_history_idx = (
            self.current_chat_history_idx
            if self.environment_type in ("competive", "competitive")
            else 0
        )
        current_long_term_memory = self.shared_memory["long_term_memory"][current_chat_history_idx:]
        current_chat_embbedings = self.shared_memory["chat_embeddings"][current_chat_history_idx:]
        if len(current_long_term_memory) > 2 * MAX_CHAT_HISTORY:
            current_long_term_memory = self._tail(current_long_term_memory, 2 * MAX_CHAT_HISTORY - 1)
            current_chat_embbedings = self._tail(current_chat_embbedings, 2 * MAX_CHAT_HISTORY - 1)
        # relevant_memory
        query = current_long_term_memory[-1].content
        if len(current_long_term_memory) > 1:
            # NOTE(review): excludes the last two entries, whereas summary() excludes
            # only the last one. With exactly two entries this searches an empty list.
            relevant_memory = get_relevant_history(
                query,
                current_long_term_memory[:-2],
                current_chat_embbedings[:-2],
            )
            relevant_memory = Memory.get_chat_history(relevant_memory, agent.name)
        else:
            relevant_memory = ""
        relevant_memory = eval(Agent_observe_relevant_memory)
        agent.relevant_memory = relevant_memory
        # get chat history from new conversation
        conversations = self._get_agent_new_memory(agent, current_long_term_memory)
        # memory = relevant_memory + summary + history + query
        # `query` is rebound to the latest memory object (not its content) for the template.
        query = current_long_term_memory[-1]
        current_memory = eval(Agent_observe_memory)
        return {"role": "user", "content": current_memory}