# MindSearch: mindsearch/agent/streaming.py
# Source: Hugging Face file view (author: vansin, commit dc9e27a "feat: update", 9.03 kB)
import copy
from typing import List, Union
from lagent.agents import Agent, AgentForInternLM, AsyncAgent, AsyncAgentForInternLM
from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode
class StreamingAgentMixin:
    """Make agent calling output a streaming response.

    Intended to be mixed into an agent whose ``forward`` is a generator:
    every chunk produced by ``forward`` is re-yielded to the caller, and one
    extra, final message (after ``after_agent`` hooks) is yielded once the
    stream is exhausted.
    """

    def __call__(self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs):
        """Stream the agent's response to ``message``.

        Args:
            *message: Input message(s); recorded in memory and forwarded to
                ``self.forward``.
            session_id: Memory session the exchange is recorded under.
            **kwargs: Extra keyword arguments passed through to ``forward``.

        Yields:
            AgentMessage: a shallow copy of every chunk produced by
            ``forward`` (raw ``(model_state, response)`` tuples are wrapped
            into an ``AgentMessage``), followed by the final response after
            the ``after_agent`` hooks have run.
        """
        for hook in self._hooks.values():
            # Each hook works on a fresh copy so it cannot mutate the caller's objects.
            message = copy.deepcopy(message)
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        # Fallback recorded/yielded if ``forward`` produces nothing.
        response_message = AgentMessage(sender=self.name, content="")
        for response_message in self.forward(*message, session_id=session_id, **kwargs):
            if not isinstance(response_message, AgentMessage):
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            yield response_message.model_copy()
        self.update_memory(response_message, session_id=session_id)
        # Always detach the final message from the instance just stored in
        # memory, even when no hooks are registered. Consumers (e.g. the
        # InternLM streaming agents) mutate ``stream_state`` on yielded
        # messages, which would otherwise rewrite this agent's memory.
        response_message = response_message.model_copy(deep=True)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message
class AsyncStreamingAgentMixin:
    """Make asynchronous agent calling output a streaming response.

    Intended to be mixed into an async agent whose ``forward`` is an async
    generator: every chunk produced by ``forward`` is re-yielded to the
    caller, and one extra, final message (after ``after_agent`` hooks) is
    yielded once the stream is exhausted.
    """

    async def __call__(
        self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs
    ):
        """Asynchronously stream the agent's response to ``message``.

        Args:
            *message: Input message(s); recorded in memory and forwarded to
                ``self.forward``.
            session_id: Memory session the exchange is recorded under.
            **kwargs: Extra keyword arguments passed through to ``forward``.

        Yields:
            AgentMessage: a shallow copy of every chunk produced by
            ``forward`` (raw ``(model_state, response)`` tuples are wrapped
            into an ``AgentMessage``), followed by the final response after
            the ``after_agent`` hooks have run.
        """
        for hook in self._hooks.values():
            # Each hook works on a fresh copy so it cannot mutate the caller's objects.
            message = copy.deepcopy(message)
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        # Fallback recorded/yielded if ``forward`` produces nothing.
        response_message = AgentMessage(sender=self.name, content="")
        async for response_message in self.forward(*message, session_id=session_id, **kwargs):
            if not isinstance(response_message, AgentMessage):
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            yield response_message.model_copy()
        self.update_memory(response_message, session_id=session_id)
        # Always detach the final message from the instance just stored in
        # memory, even when no hooks are registered. Consumers (e.g. the
        # InternLM streaming agents) mutate ``stream_state`` on yielded
        # messages, which would otherwise rewrite this agent's memory.
        response_message = response_message.model_copy(deep=True)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message
class StreamingAgent(StreamingAgentMixin, Agent):
    """Base streaming agent class.

    ``forward`` streams the LLM completion chunk by chunk instead of
    returning one finished response.
    """

    def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Stream an LLM completion for the conversation stored in memory.

        The positional ``message`` arguments are not read here: the prompt is
        built from ``self.memory`` for ``session_id``.

        Yields:
            ``AgentMessage`` chunks (with ``formatted`` parsed by
            ``self.output_format``) when an output format is configured,
            otherwise raw ``(model_state, response)`` tuples.
        """
        history = self.memory.get(session_id)
        prompt = self.aggregator.aggregate(
            history, self.name, self.output_format, self.template
        )
        chunks = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        for state, text, _ in chunks:
            if not self.output_format:
                # No parser configured: hand the raw pair to the mixin.
                yield (state, text)
                continue
            yield AgentMessage(
                sender=self.name,
                content=text,
                formatted=self.output_format.parse_response(text),
                stream_state=state,
            )
class AsyncStreamingAgent(AsyncStreamingAgentMixin, AsyncAgent):
    """Base asynchronous streaming agent class.

    ``forward`` is an async generator that streams the LLM completion chunk
    by chunk instead of returning one finished response.
    """

    async def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Asynchronously stream an LLM completion for the stored conversation.

        The positional ``message`` arguments are not read here: the prompt is
        built from ``self.memory`` for ``session_id``.

        Yields:
            ``AgentMessage`` chunks (with ``formatted`` parsed by
            ``self.output_format``) when an output format is configured,
            otherwise raw ``(model_state, response)`` tuples.
        """
        history = self.memory.get(session_id)
        prompt = self.aggregator.aggregate(
            history, self.name, self.output_format, self.template
        )
        chunks = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        async for state, text, _ in chunks:
            if not self.output_format:
                # No parser configured: hand the raw pair to the mixin.
                yield (state, text)
                continue
            yield AgentMessage(
                sender=self.name,
                content=text,
                formatted=self.output_format.parse_response(text),
                stream_state=state,
            )
class StreamingAgentForInternLM(StreamingAgentMixin, AgentForInternLM):
    """Streaming implementation of `lagent.agents.AgentForInternLM`.

    Runs up to ``max_turn`` rounds of "model turn, then optional tool call".
    Every streamed chunk is re-yielded with its ``stream_state`` relabelled
    as an ``AgentStatusCode``.
    """

    _INTERNAL_AGENT_CLS = StreamingAgent

    @staticmethod
    def _tool_call_state(message, last_agent_state):
        """Return the stream state for a chunk whose parsed output names a tool.

        While the model is still streaming, the state is ``PLUGIN_START`` for
        tool type ``"plugin"`` and ``CODING`` for any other tool type. On the
        model's END chunk, the state advances one step past that start state.
        """
        tool_type = message.formatted["tool_type"]
        start_state = (
            AgentStatusCode.PLUGIN_START if tool_type == "plugin" else AgentStatusCode.CODING
        )
        if message.stream_state != ModelStatusCode.END:
            return start_state
        if last_agent_state in (AgentStatusCode.CODING, AgentStatusCode.PLUGIN_START):
            # Usual case: the call was already streaming; advance START -> START + 1.
            return last_agent_state + 1
        # The tool call was only recognised on the final chunk, so no START
        # state was emitted. Keeping ``last_agent_state`` (SESSION_READY or
        # STREAM_ING) would mislabel the call and make the tool result's
        # ``state + 1`` meaningless; advance from the tool's own START state.
        return start_state + 1

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Run the tool-calling loop, streaming every intermediate message.

        Args:
            message: The user input; a plain ``str`` is wrapped into an
                ``AgentMessage`` sent by ``"user"``.
            session_id: Memory session to use.
            **kwargs: Passed through to the inner agent.

        Yields:
            AgentMessage chunks with ``stream_state`` relabelled as an
            ``AgentStatusCode``. When ``finish_condition`` holds, the last
            message is yielded once more with ``AgentStatusCode.END``.

        Raises:
            RuntimeError: If the model requests a tool type for which no
                ``<tool_type>_executor`` attribute is available.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    message.stream_state = self._tool_call_state(message, last_agent_state)
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = executor(message, session_id=session_id)
                # One step past the call's final state (presumably the
                # matching *_RETURN code) -- TODO confirm against AgentStatusCode.
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                # No tool requested and not finished: this reply is fed back
                # to the inner agent on the next turn.
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message
class AsyncStreamingAgentForInternLM(AsyncStreamingAgentMixin, AsyncAgentForInternLM):
    """Streaming implementation of `lagent.agents.AsyncAgentForInternLM`.

    Runs up to ``max_turn`` rounds of "model turn, then optional tool call".
    Every streamed chunk is re-yielded with its ``stream_state`` relabelled
    as an ``AgentStatusCode``.
    """

    _INTERNAL_AGENT_CLS = AsyncStreamingAgent

    @staticmethod
    def _tool_call_state(message, last_agent_state):
        """Return the stream state for a chunk whose parsed output names a tool.

        While the model is still streaming, the state is ``PLUGIN_START`` for
        tool type ``"plugin"`` and ``CODING`` for any other tool type. On the
        model's END chunk, the state advances one step past that start state.
        """
        tool_type = message.formatted["tool_type"]
        start_state = (
            AgentStatusCode.PLUGIN_START if tool_type == "plugin" else AgentStatusCode.CODING
        )
        if message.stream_state != ModelStatusCode.END:
            return start_state
        if last_agent_state in (AgentStatusCode.CODING, AgentStatusCode.PLUGIN_START):
            # Usual case: the call was already streaming; advance START -> START + 1.
            return last_agent_state + 1
        # The tool call was only recognised on the final chunk, so no START
        # state was emitted. Keeping ``last_agent_state`` (SESSION_READY or
        # STREAM_ING) would mislabel the call and make the tool result's
        # ``state + 1`` meaningless; advance from the tool's own START state.
        return start_state + 1

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Asynchronously run the tool-calling loop, streaming every message.

        Args:
            message: The user input; a plain ``str`` is wrapped into an
                ``AgentMessage`` sent by ``"user"``.
            session_id: Memory session to use.
            **kwargs: Passed through to the inner agent.

        Yields:
            AgentMessage chunks with ``stream_state`` relabelled as an
            ``AgentStatusCode``. When ``finish_condition`` holds, the last
            message is yielded once more with ``AgentStatusCode.END``.

        Raises:
            RuntimeError: If the model requests a tool type for which no
                ``<tool_type>_executor`` attribute is available.
        """
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            async for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    message.stream_state = self._tool_call_state(message, last_agent_state)
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = await executor(message, session_id=session_id)
                # One step past the call's final state (presumably the
                # matching *_RETURN code) -- TODO confirm against AgentStatusCode.
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                # No tool requested and not finished: this reply is fed back
                # to the inner agent on the next turn.
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message