| |
| """ |
| Cap'n Proto RPC Server for Agent Interface |
| """ |
|
|
| import asyncio |
| import logging |
| import os |
| import pickle |
| import numpy as np |
| import torch |
| import capnp |
|
|
| |
logger = logging.getLogger(__name__)

# The Cap'n Proto schema file sits in the same directory as this module.
schema_file = os.path.join(os.path.dirname(__file__), "agent.capnp")
# Compile the schema at import time; AgentServer subclasses the generated
# `Agent.Server` interface, so this must happen before that class is defined.
agent_capnp = capnp.load(schema_file)
|
|
|
|
class AgentServer(agent_capnp.Agent.Server):
    """Cap'n Proto server implementation for AgentInterface.

    Adapts a local agent object (anything providing ``act(observation)`` and
    ``reset()``) to the ``Agent`` RPC interface generated from ``agent.capnp``.
    """

    def __init__(self, agent):
        """Wrap ``agent`` so its methods can be invoked over RPC.

        Args:
            agent: Object providing ``act(observation)`` and ``reset()``.
        """
        self.agent = agent
        self.logger = logging.getLogger(__name__)
        self.logger.info("AgentServer initialized with agent: %s", type(agent).__name__)

    async def act(self, obs, **kwargs):
        """Handle the ``act`` RPC call.

        Args:
            obs: Pickled observation bytes sent by the client.
            **kwargs: Extra arguments passed by pycapnp (presumably the call
                ``_context`` -- verify); unused here.

        Returns:
            A new ``Agent.Tensor`` message whose ``data`` is the action's raw
            bytes in C order, ``shape`` its dimensions, and ``dtype`` the numpy
            dtype name (e.g. ``"float32"``).

        Raises:
            Any exception from unpickling, the agent, or the conversion; it is
            logged with its traceback and then re-raised.
        """
        try:
            # SECURITY: pickle.loads executes arbitrary code embedded in the
            # payload, and `obs` arrives straight off the network. Any peer that
            # can reach this port can run code in this process. Only expose the
            # server to trusted clients (e.g. keep binding to 127.0.0.1), or move
            # the request to a raw-bytes/shape/dtype encoding like the response.
            observation = pickle.loads(obs)

            # NOTE(review): agent.act runs synchronously inside this coroutine,
            # so a slow policy stalls the event loop (and every other connection)
            # until it returns.
            action = self.agent.act(observation)

            if isinstance(action, torch.Tensor):
                # Drop autograd history and move off any accelerator first:
                # Tensor.numpy() only works on CPU tensors that don't require grad.
                action_array = action.detach().cpu().numpy()
            else:
                # asarray avoids copying when the agent already returns an
                # ndarray; the array is only read below, never mutated.
                action_array = np.asarray(action)

            response = agent_capnp.Agent.Tensor.new_message()
            response.data = action_array.tobytes()  # C-order bytes regardless of layout
            response.shape = list(action_array.shape)
            response.dtype = str(action_array.dtype)
            return response
        except Exception as e:
            self.logger.exception("Error in act: %s", e)
            raise

    async def reset(self, **kwargs):
        """Handle the ``reset`` RPC call by resetting the wrapped agent.

        Args:
            **kwargs: Extra arguments passed by pycapnp; unused here.

        Raises:
            Any exception from ``agent.reset()``; it is logged with its
            traceback and then re-raised.
        """
        try:
            self.agent.reset()
        except Exception as e:
            self.logger.exception("Error in reset: %s", e)
            raise
|
|
|
|
async def serve(agent, address="127.0.0.1", port=8000):
    """Serve ``agent`` over Cap'n Proto RPC on ``address:port``.

    Each accepted connection gets its own ``AgentServer`` wrapping the shared
    ``agent`` and is served until the peer disconnects.

    Errors while creating the listening server propagate to the caller.
    Errors raised by the serving loop are logged and swallowed, so this
    coroutine then returns normally. "Server shutting down" is logged on
    every exit from that loop.

    Args:
        agent: Object providing ``act(observation)`` and ``reset()``.
        address: Host/interface to listen on.
        port: TCP port to listen on.
    """

    async def handle_client(stream):
        """Serve one client stream until it disconnects; log, never raise."""
        try:
            rpc_endpoint = capnp.TwoPartyServer(stream, bootstrap=AgentServer(agent))
            # Keep this handler alive for as long as the connection is open.
            await rpc_endpoint.on_disconnect()
        except Exception as e:
            logger.error(f"Error handling connection: {e}", exc_info=True)

    listener = await capnp.AsyncIoStream.create_server(handle_client, address, port)
    logger.info(f"Agent RPC server listening on {address}:{port}")

    try:
        async with listener:
            await listener.serve_forever()
    except Exception as e:
        logger.error(f"Server error: {e}", exc_info=True)
    finally:
        logger.info("Server shutting down")
|
|
|
|
def start_server(agent, address="127.0.0.1", port=8000):
    """Blocking entry point that serves ``agent`` until interrupted.

    Starts a fresh asyncio event loop, enters pycapnp's KJ loop for the whole
    lifetime of ``serve``, and treats Ctrl+C (KeyboardInterrupt) as a clean
    stop: it is logged instead of propagating.

    Args:
        agent: Object providing ``act(observation)`` and ``reset()``.
        address: Host/interface to listen on.
        port: TCP port to listen on.
    """

    async def _main():
        async with capnp.kj_loop():
            await serve(agent, address, port)

    try:
        asyncio.run(_main())
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
|
|
|
|
def run_server_in_process(agent, address="127.0.0.1", port=8000):
    """Entry point meant to be the target of a separate process.

    Runs ``serve`` inside pycapnp's KJ loop on a fresh asyncio event loop and
    blocks until it finishes. Unlike ``start_server``, nothing is caught here:
    KeyboardInterrupt and any startup error propagate to the caller.

    Args:
        agent: Object providing ``act(observation)`` and ``reset()``.
        address: Host/interface to listen on.
        port: TCP port to listen on.
    """

    async def _serve_under_kj():
        async with capnp.kj_loop():
            await serve(agent, address, port)

    asyncio.run(_serve_under_kj())
|
|