Spaces:
Sleeping
Sleeping
import copy | |
from typing import List, Union | |
from lagent.agents import Agent, AgentForInternLM, AsyncAgent, AsyncAgentForInternLM | |
from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode | |
class StreamingAgentMixin:
    """Mixin that makes calling an agent return a stream of responses.

    The host class must supply ``name``, ``_hooks`` (a mapping whose values
    expose ``before_agent`` / ``after_agent``), ``update_memory`` and an
    iterable-returning ``forward``.
    """

    def __call__(self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs):
        """Stream the agent's reply to ``message``.

        Order of operations: run every ``before_agent`` hook on a deep copy of
        the input, store the (possibly replaced) input in memory, stream
        ``forward``, store the last streamed message in memory, run every
        ``after_agent`` hook on a deep copy of it, then yield the result.

        Args:
            *message: Input message(s) for this turn.
            session_id: Session key passed to memory, hooks and ``forward``.
            **kwargs: Passed through unchanged to ``forward``.

        Yields:
            A ``model_copy()`` of every intermediate message (bare
            ``(model_state, response)`` tuples from ``forward`` are wrapped
            into ``AgentMessage`` first), followed by the post-hook final
            message.
        """
        inputs = message
        for hook in self._hooks.values():
            # Each hook works on its own deep copy; a truthy return replaces the input.
            inputs = copy.deepcopy(inputs)
            replacement = hook.before_agent(self, inputs, session_id)
            if replacement:
                inputs = replacement
        self.update_memory(inputs, session_id=session_id)

        # Placeholder that is stored/returned if ``forward`` yields nothing.
        latest = AgentMessage(sender=self.name, content="")
        for item in self.forward(*inputs, session_id=session_id, **kwargs):
            if isinstance(item, AgentMessage):
                latest = item
            else:
                model_state, response = item
                latest = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            # Hand out a copy rather than the live object.
            yield latest.model_copy()

        self.update_memory(latest, session_id=session_id)
        for hook in self._hooks.values():
            latest = latest.model_copy(deep=True)
            replacement = hook.after_agent(self, latest, session_id)
            if replacement:
                latest = replacement
        yield latest
class AsyncStreamingAgentMixin:
    """Mixin that makes calling an asynchronous agent return an async stream.

    The host class must supply ``name``, ``_hooks`` (a mapping whose values
    expose synchronous ``before_agent`` / ``after_agent``), ``update_memory``
    and an async-iterable-returning ``forward``.
    """

    async def __call__(
        self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs
    ):
        """Asynchronously stream the agent's reply to ``message``.

        Order of operations: run every ``before_agent`` hook on a deep copy of
        the input, store the (possibly replaced) input in memory, iterate
        ``forward`` with ``async for``, store the last streamed message in
        memory, run every ``after_agent`` hook on a deep copy of it, then
        yield the result. Hooks are called without ``await``.

        Args:
            *message: Input message(s) for this turn.
            session_id: Session key passed to memory, hooks and ``forward``.
            **kwargs: Passed through unchanged to ``forward``.

        Yields:
            A ``model_copy()`` of every intermediate message (bare
            ``(model_state, response)`` tuples from ``forward`` are wrapped
            into ``AgentMessage`` first), followed by the post-hook final
            message.
        """
        inputs = message
        for hook in self._hooks.values():
            # Each hook works on its own deep copy; a truthy return replaces the input.
            inputs = copy.deepcopy(inputs)
            replacement = hook.before_agent(self, inputs, session_id)
            if replacement:
                inputs = replacement
        self.update_memory(inputs, session_id=session_id)

        # Placeholder that is stored/returned if ``forward`` yields nothing.
        latest = AgentMessage(sender=self.name, content="")
        async for item in self.forward(*inputs, session_id=session_id, **kwargs):
            if isinstance(item, AgentMessage):
                latest = item
            else:
                model_state, response = item
                latest = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            # Hand out a copy rather than the live object.
            yield latest.model_copy()

        self.update_memory(latest, session_id=session_id)
        for hook in self._hooks.values():
            latest = latest.model_copy(deep=True)
            replacement = hook.after_agent(self, latest, session_id)
            if replacement:
                latest = replacement
        yield latest
class StreamingAgent(StreamingAgentMixin, Agent):
    """Base streaming agent class.

    ``forward`` is a generator over the LLM's streamed output; calling the
    agent goes through ``StreamingAgentMixin.__call__``.
    """

    def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Stream the model's reply to the conversation stored for ``session_id``.

        Args:
            *message: Accepted for interface compatibility; not read here
                (presumably already written to memory by the caller).
            session_id: Memory session used to build the prompt; also
                forwarded to ``self.llm.stream_chat``.
            **kwargs: Extra keyword arguments for ``self.llm.stream_chat``.

        Yields:
            When ``self.output_format`` is set, an ``AgentMessage`` with the
            raw response, its ``parse_response`` result and the model state;
            otherwise a bare ``(model_state, response)`` tuple.
        """
        history = self.memory.get(session_id)
        prompt = self.aggregator.aggregate(
            history, self.name, self.output_format, self.template
        )
        stream = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        # The third element of each streamed item is not used here.
        for model_state, response, _unused in stream:
            if not self.output_format:
                yield model_state, response
                continue
            yield AgentMessage(
                sender=self.name,
                content=response,
                formatted=self.output_format.parse_response(response),
                stream_state=model_state,
            )
class AsyncStreamingAgent(AsyncStreamingAgentMixin, AsyncAgent):
    """Base asynchronous streaming agent class.

    ``forward`` is an async generator over the LLM's streamed output; calling
    the agent goes through ``AsyncStreamingAgentMixin.__call__``.
    """

    async def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        """Asynchronously stream the model's reply for ``session_id``.

        Args:
            *message: Accepted for interface compatibility; not read here
                (presumably already written to memory by the caller).
            session_id: Memory session used to build the prompt; also
                forwarded to ``self.llm.stream_chat``.
            **kwargs: Extra keyword arguments for ``self.llm.stream_chat``.

        Yields:
            When ``self.output_format`` is set, an ``AgentMessage`` with the
            raw response, its ``parse_response`` result and the model state;
            otherwise a bare ``(model_state, response)`` tuple.
        """
        history = self.memory.get(session_id)
        prompt = self.aggregator.aggregate(
            history, self.name, self.output_format, self.template
        )
        stream = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        # The third element of each streamed item is not used here.
        async for model_state, response, _unused in stream:
            if not self.output_format:
                yield model_state, response
                continue
            yield AgentMessage(
                sender=self.name,
                content=response,
                formatted=self.output_format.parse_response(response),
                stream_state=model_state,
            )
class StreamingAgentForInternLM(StreamingAgentMixin, AgentForInternLM):
    """Streaming implementation of `lagent.agents.AgentForInternLM`"""

    _INTERNAL_AGENT_CLS = StreamingAgent

    def forward(self, message: Union[AgentMessage, str], session_id=0, **kwargs):
        """Run up to ``self.max_turn`` model/tool rounds, streaming every step.

        Each round streams the inner agent ``self.agent`` and tags every chunk
        with an ``AgentStatusCode``. If ``self.finish_condition`` accepts the
        round's last message, it is re-yielded with ``AgentStatusCode.END`` and
        the generator stops. Otherwise, if the message requests a tool, the
        matching ``<tool_type>_executor`` attribute is called and its return
        value becomes the next round's input. If no tool is requested, the
        message itself is fed back as the next round's input.

        Args:
            message: Input message; a plain ``str`` is wrapped into an
                ``AgentMessage`` with ``sender="user"``.
            session_id: Session key forwarded to the inner agent and executors.
            **kwargs: Extra keyword arguments forwarded to the inner agent.

        Yields:
            ``AgentMessage`` objects whose ``stream_state`` was set here. The
            same object may be yielded more than once with an updated
            ``stream_state``.

        Raises:
            RuntimeError: If a tool type is requested but no
                ``<tool_type>_executor`` attribute is available.
        """

        def tool_type_of(msg):
            # Only a dict ``formatted`` can carry a tool request; anything else
            # (or a missing key) is treated as "no tool requested".
            formatted = msg.formatted
            return formatted.get("tool_type") if isinstance(formatted, dict) else None

        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            for message in self.agent(message, session_id=session_id, **kwargs):
                tool_type = tool_type_of(message)
                if tool_type:
                    if message.stream_state == ModelStatusCode.END:
                        # Model finished emitting the tool call: advance a
                        # CODING / PLUGIN_START state by one (presumably to the
                        # matching *_END code; depends on AgentStatusCode
                        # ordering, confirm). Any other state is kept.
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if tool_type == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            # Use the same guard as the stream loop so a non-dict or key-less
            # ``formatted`` takes the no-tool branch instead of raising.
            tool_type = tool_type_of(message)
            if tool_type:
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = executor(message, session_id=session_id)
                # Tool result state is one step past the tool-call state
                # (presumably the matching *_RETURN code; confirm).
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message
class AsyncStreamingAgentForInternLM(AsyncStreamingAgentMixin, AsyncAgentForInternLM):
    """Streaming implementation of `lagent.agents.AsyncAgentForInternLM`"""

    _INTERNAL_AGENT_CLS = AsyncStreamingAgent

    async def forward(self, message: Union[AgentMessage, str], session_id=0, **kwargs):
        """Run up to ``self.max_turn`` model/tool rounds, streaming every step.

        Each round streams the inner agent ``self.agent`` (via ``async for``)
        and tags every chunk with an ``AgentStatusCode``. If
        ``self.finish_condition`` accepts the round's last message, it is
        re-yielded with ``AgentStatusCode.END`` and the generator stops.
        Otherwise, if the message requests a tool, the matching
        ``<tool_type>_executor`` attribute is awaited and its return value
        becomes the next round's input. If no tool is requested, the message
        itself is fed back as the next round's input.

        Args:
            message: Input message; a plain ``str`` is wrapped into an
                ``AgentMessage`` with ``sender="user"``.
            session_id: Session key forwarded to the inner agent and executors.
            **kwargs: Extra keyword arguments forwarded to the inner agent.

        Yields:
            ``AgentMessage`` objects whose ``stream_state`` was set here. The
            same object may be yielded more than once with an updated
            ``stream_state``.

        Raises:
            RuntimeError: If a tool type is requested but no
                ``<tool_type>_executor`` attribute is available.
        """

        def tool_type_of(msg):
            # Only a dict ``formatted`` can carry a tool request; anything else
            # (or a missing key) is treated as "no tool requested".
            formatted = msg.formatted
            return formatted.get("tool_type") if isinstance(formatted, dict) else None

        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            async for message in self.agent(message, session_id=session_id, **kwargs):
                tool_type = tool_type_of(message)
                if tool_type:
                    if message.stream_state == ModelStatusCode.END:
                        # Model finished emitting the tool call: advance a
                        # CODING / PLUGIN_START state by one (presumably to the
                        # matching *_END code; depends on AgentStatusCode
                        # ordering, confirm). Any other state is kept.
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if tool_type == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            # Use the same guard as the stream loop so a non-dict or key-less
            # ``formatted`` takes the no-tool branch instead of raising.
            tool_type = tool_type_of(message)
            if tool_type:
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = await executor(message, session_id=session_id)
                # Tool result state is one step past the tool-call state
                # (presumably the matching *_RETURN code; confirm).
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message