import json
import tempfile

import requests
import streamlit as st
from lagent.schema import AgentStatusCode
from pyvis.network import Network


# Function to create the network graph
def create_network_graph(nodes, adjacency_list):
    net = Network(height="500px", width="60%", bgcolor="white", font_color="black")
    for node_id, node_content in nodes.items():
        net.add_node(node_id, label=node_id, title=node_content, color="#FF5733", size=25)
    for node_id, neighbors in adjacency_list.items():
        for neighbor in neighbors:
            if neighbor["name"] in nodes:
                net.add_edge(node_id, neighbor["name"])
    net.show_buttons(filter_=["physics"])
    return net
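
# Example with hypothetical data (the real nodes/adjacency_list come from the
# backend stream): nodes maps node ids to display text, adjacency_list maps each
# node id to a list of {"name": ...} neighbor dicts.
#   nodes = {"root": "what is MindSearch?", "search_1": "search_1"}
#   adjacency_list = {"root": [{"name": "search_1"}]}
#   net = create_network_graph(nodes, adjacency_list)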


# Function to draw the graph and return the HTML file path
def draw_graph(net):
    path = tempfile.mktemp(suffix=".html")
    net.save_graph(path)
    return path
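
# draw_graph only stages the pyvis output: save_graph writes the HTML to a
# throwaway temp file, which update_chat reads back and embeds inline, so the
# file itself is never served directly.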


def streaming(raw_response):
    for chunk in raw_response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\n"):
        if chunk:
            decoded = chunk.decode("utf-8")
            if decoded == "\r":
                continue
            if decoded[:6] == "data: ":
                decoded = decoded[6:]
            elif decoded.startswith(": ping - "):
                continue
            response = json.loads(decoded)
            yield (
                response["current_node"],
                (
                    response["response"]["formatted"]["node"][response["current_node"]]["response"]
                    if response["current_node"]
                    else response["response"]
                ),
                response["response"]["formatted"]["adjacency_list"],
            )
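
# The backend is assumed to emit server-sent-event style lines such as
#   data: {"current_node": ..., "response": {...}}
# plus ": ping - ..." keep-alives. streaming() strips the "data: " prefix,
# skips pings, and yields one (current_node, node_response, adjacency_list)
# tuple per decoded chunk.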


# Initialize Streamlit session state
if "queries" not in st.session_state:
    st.session_state["queries"] = []
    st.session_state["responses"] = []
    st.session_state["graphs_html"] = []
    st.session_state["nodes_list"] = []
    st.session_state["adjacency_list_list"] = []
    st.session_state["history"] = []
    st.session_state["already_used_keys"] = list()

# Set up page layout
st.set_page_config(layout="wide")
st.title("MindSearch-思索")


# Function to update chat
def update_chat(query):
    with st.chat_message("user"):
        st.write(query)
    if query not in st.session_state["queries"]:
        # Send the query to the backend and stream back the
        # response, history, nodes and adjacency_list
        st.session_state["queries"].append(query)
        st.session_state["responses"].append([])
        history = None
        # Multi-turn dialogue is not supported yet
        # message = [dict(role='user', content=query)]
        url = "http://localhost:8002/solve"
        headers = {"Content-Type": "application/json"}
        data = {"inputs": query}
        raw_response = requests.post(
            url, headers=headers, data=json.dumps(data), timeout=20, stream=True
        )
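        # The /solve endpoint and the {"inputs": query} payload follow the
        # MindSearch FastAPI backend, assumed here to be running separately on
        # localhost:8002; adjust the URL above if it is served elsewhere.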
        _nodes, _node_cnt = {}, 0
        for resp in streaming(raw_response):
            node_name, response, adjacency_list = resp
            for name in set(adjacency_list) | {
                val["name"] for vals in adjacency_list.values() for val in vals
            }:
                if name not in _nodes:
                    _nodes[name] = query if name == "root" else name
                elif response["stream_state"] == 0:
                    _nodes[node_name or "response"] = response["formatted"] and response[
                        "formatted"
                    ].get("thought")
            if len(_nodes) != _node_cnt or response["stream_state"] == 0:
                net = create_network_graph(_nodes, adjacency_list)
                graph_html_path = draw_graph(net)
                with open(graph_html_path, encoding="utf-8") as f:
                    graph_html = f.read()
                _node_cnt = len(_nodes)
            else:
                graph_html = None
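            # Rebuild the pyvis graph only when the node set grows or this chunk
            # reports stream_state 0; otherwise graph_html stays None and the
            # previously rendered graph is left as-is.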
if "graph_placeholder" not in st.session_state: | |
st.session_state["graph_placeholder"] = st.empty() | |
if "expander_placeholder" not in st.session_state: | |
st.session_state["expander_placeholder"] = st.empty() | |
if graph_html: | |
with st.session_state["expander_placeholder"].expander( | |
"Show Graph", expanded=False | |
): | |
st.session_state["graph_placeholder"]._html(graph_html, height=500) | |
if "container_placeholder" not in st.session_state: | |
st.session_state["container_placeholder"] = st.empty() | |
with st.session_state["container_placeholder"].container(): | |
if "columns_placeholder" not in st.session_state: | |
st.session_state["columns_placeholder"] = st.empty() | |
col1, col2 = st.session_state["columns_placeholder"].columns([2, 1]) | |
with col1: | |
if "planner_placeholder" not in st.session_state: | |
st.session_state["planner_placeholder"] = st.empty() | |
if "session_info_temp" not in st.session_state: | |
st.session_state["session_info_temp"] = "" | |
                    if not node_name:
                        if response["stream_state"] in [
                            AgentStatusCode.STREAM_ING,
                            AgentStatusCode.CODING,
                            AgentStatusCode.CODE_END,
                        ]:
                            content = response["formatted"]["thought"]
                            if response["formatted"]["tool_type"]:
                                action = response["formatted"]["action"]
                                if isinstance(action, dict):
                                    action = json.dumps(action, ensure_ascii=False, indent=4)
                                content += "\n" + action
                            st.session_state["session_info_temp"] = content.replace(
                                "<|action_start|><|interpreter|>\n", "\n"
                            )
                        elif response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            # assert history[-1]["role"] == "environment"
                            st.session_state["session_info_temp"] += "\n" + response["content"]
                        st.session_state["planner_placeholder"].markdown(
                            st.session_state["session_info_temp"]
                        )
                        if response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            st.session_state["responses"][-1].append(
                                st.session_state["session_info_temp"]
                            )
                            st.session_state["session_info_temp"] = ""
                    else:
                        st.session_state["planner_placeholder"].markdown(
                            st.session_state["responses"][-1][-1]
                            if not st.session_state["session_info_temp"]
                            else st.session_state["session_info_temp"]
                        )
                with col2:
                    if "selectbox_placeholder" not in st.session_state:
                        st.session_state["selectbox_placeholder"] = st.empty()
                    if "searcher_placeholder" not in st.session_state:
                        st.session_state["searcher_placeholder"] = st.empty()
                    if node_name:
                        selected_node_key = (
                            f"selected_node_{len(st.session_state['queries'])}_{node_name}"
                        )
                        if selected_node_key not in st.session_state:
                            st.session_state[selected_node_key] = node_name
                        if selected_node_key not in st.session_state["already_used_keys"]:
                            selected_node = st.session_state["selectbox_placeholder"].selectbox(
                                "Select a node:",
                                list(_nodes.keys()),
                                key=f"key_{selected_node_key}",
                                index=list(_nodes.keys()).index(node_name),
                            )
                            st.session_state["already_used_keys"].append(selected_node_key)
                        else:
                            selected_node = node_name
                        st.session_state[selected_node_key] = selected_node
                        node_info_key = f"{selected_node}_info"
                        if node_info_key not in st.session_state:
                            st.session_state[node_info_key] = [["thought", ""]]
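                        # Each node's transcript is a list of [kind, text] pairs:
                        # "thought" entries hold streamed model output, "observation"
                        # entries hold tool/search results appended below.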
if response["stream_state"] in [AgentStatusCode.STREAM_ING]: | |
content = response["formatted"]["thought"] | |
st.session_state[node_info_key][-1][1] = content.replace( | |
"<|action_start|><|plugin|>\n", "\n```json\n" | |
) | |
elif response["stream_state"] in [ | |
AgentStatusCode.PLUGIN_START, | |
AgentStatusCode.PLUGIN_END, | |
]: | |
thought = response["formatted"]["thought"] | |
action = response["formatted"]["action"] | |
if isinstance(action, dict): | |
action = json.dumps(action, ensure_ascii=False, indent=4) | |
content = thought + "\n```json\n" + action | |
if response["stream_state"] == AgentStatusCode.PLUGIN_RETURN: | |
content += "\n```" | |
st.session_state[node_info_key][-1][1] = content | |
elif ( | |
response["stream_state"] == AgentStatusCode.PLUGIN_RETURN | |
and st.session_state[node_info_key][-1][1] | |
): | |
try: | |
content = json.loads(response["content"]) | |
except json.decoder.JSONDecodeError: | |
content = response["content"] | |
st.session_state[node_info_key].append( | |
[ | |
"observation", | |
( | |
content | |
if isinstance(content, str) | |
else f"```json\n{json.dumps(content, ensure_ascii=False, indent=4)}\n```" | |
), | |
] | |
) | |
st.session_state["searcher_placeholder"].markdown( | |
st.session_state[node_info_key][-1][1] | |
) | |
if ( | |
response["stream_state"] == AgentStatusCode.PLUGIN_RETURN | |
and st.session_state[node_info_key][-1][1] | |
): | |
st.session_state[node_info_key].append(["thought", ""]) | |
if st.session_state["session_info_temp"]: | |
st.session_state["responses"][-1].append(st.session_state["session_info_temp"]) | |
st.session_state["session_info_temp"] = "" | |
# st.session_state['responses'][-1] = '\n'.join(st.session_state['responses'][-1]) | |
st.session_state["graphs_html"].append(graph_html) | |
st.session_state["nodes_list"].append(_nodes) | |
st.session_state["adjacency_list_list"].append(adjacency_list) | |
st.session_state["history"] = history | |


def display_chat_history():
    for i, query in enumerate(st.session_state["queries"][-1:]):
        # with st.chat_message('assistant'):
        if st.session_state["graphs_html"][i]:
            with st.session_state["expander_placeholder"].expander("Show Graph", expanded=False):
                st.session_state["graph_placeholder"]._html(
                    st.session_state["graphs_html"][i], height=500
                )
        with st.session_state["container_placeholder"].container():
            col1, col2 = st.session_state["columns_placeholder"].columns([2, 1])
            with col1:
                st.session_state["planner_placeholder"].markdown(
                    st.session_state["responses"][-1][-1]
                )
            with col2:
                selected_node_key = st.session_state["already_used_keys"][-1]
                st.session_state["selectbox_placeholder"] = st.empty()
                selected_node = st.session_state["selectbox_placeholder"].selectbox(
                    "Select a node:",
                    list(st.session_state["nodes_list"][i].keys()),
                    key=f"replay_key_{i}",
                    index=list(st.session_state["nodes_list"][i].keys()).index(
                        st.session_state[selected_node_key]
                    ),
                )
                st.session_state[selected_node_key] = selected_node
                if (
                    selected_node not in ["root", "response"]
                    and selected_node in st.session_state["nodes_list"][i]
                ):
                    node_info_key = f"{selected_node}_info"
                    for item in st.session_state[node_info_key]:
                        if item[0] in ["thought", "answer"]:
                            st.session_state["searcher_placeholder"] = st.empty()
                            st.session_state["searcher_placeholder"].markdown(item[1])
                        elif item[0] == "observation":
                            st.session_state["observation_expander"] = st.empty()
                            with st.session_state["observation_expander"].expander("Results"):
                                st.write(item[1])
                    # st.session_state['searcher_placeholder'].markdown(st.session_state[node_info_key])


def clean_history():
    st.session_state["queries"] = []
    st.session_state["responses"] = []
    st.session_state["graphs_html"] = []
    st.session_state["nodes_list"] = []
    st.session_state["adjacency_list_list"] = []
    st.session_state["history"] = []
    st.session_state["already_used_keys"] = list()
    # Copy the keys first: deleting from st.session_state while iterating over it
    # can raise a "dictionary changed size during iteration" error.
    for k in list(st.session_state):
        if k.endswith("placeholder") or k.endswith("_info"):
            del st.session_state[k]


# Main function to run the Streamlit app
def main():
    st.sidebar.title("Model Control")
    col1, col2 = st.columns([4, 1])
    with col1:
        user_input = st.chat_input("Enter your query:")
    with col2:
        if st.button("Clear History"):
            clean_history()
    if user_input:
        update_chat(user_input)
        display_chat_history()


if __name__ == "__main__":
    main()
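
# To try this UI (assuming the MindSearch backend is already serving
# http://localhost:8002/solve), launch it with:
#   streamlit run <path/to/this/file>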