import itertools
import logging
import logging.handlers
import os
import pickle  # nosec
import time
from concurrent import futures
from queue import Queue
from typing import Generator, List, Optional

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore
import grpc
import torch
from datasets import load_dataset

from lerobot.common.policies.factory import get_policy_class
from lerobot.scripts.server.robot_client import (
    TimedAction,
    TimedObservation,
    TinyPolicyConfig,
    environment_dt,
)

# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)

# Set up logging with both console and file output
logger = logging.getLogger("policy_server")
logger.setLevel(logging.INFO)

# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(
    logging.Formatter("%(asctime)s [SERVER] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
)
logger.addHandler(console_handler)

# File handler - creates a new log file for each run
file_handler = logging.handlers.RotatingFileHandler(
    f"logs/policy_server_{int(time.time())}.log",
    maxBytes=10 * 1024 * 1024,  # 10MB
    backupCount=5,
)
file_handler.setFormatter(
    logging.Formatter("%(asctime)s [SERVER] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
)
logger.addHandler(file_handler)

inference_latency = 1 / 3  # artificial delay (seconds) emulating policy inference time
idle_wait = 0.1  # seconds to wait when no observation is available yet

supported_policies = ["act"]
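
# Client-side sketch of the policy setup this server expects (illustrative only,
# not executed here). The field names mirror the attributes read in
# SendPolicyInstructions below; the constructor signature and the checkpoint
# path are assumptions, and the exact request message type comes from
# async_inference.proto, which is not shown in this file:
#
#   config = TinyPolicyConfig(
#       policy_type="act",
#       pretrained_name_or_path="lerobot/act_so100_test",  # hypothetical checkpoint
#       device="cuda",
#   )
#   payload = pickle.dumps(config)  # sent as the request's `data` field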


class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
    def __init__(self):
        # Initialize dataset action generator
        self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())
        self._setup_server()

        self.actions_per_chunk = 20
        self.actions_overlap = 10

        self.running = True  # running flag controlling the server lifetime

    def _setup_server(self) -> None:
        """Flushes server state when a new client connects."""
        # only running inference on the latest observation received by the server
        self.observation_queue = Queue(maxsize=1)

    def Ready(self, request, context):  # noqa: N802
        client_id = context.peer()
        logger.info(f"Client {client_id} connected and ready")
        self._setup_server()
        return async_inference_pb2.Empty()

    def SendPolicyInstructions(self, request, context):  # noqa: N802
        """Receive policy instructions from the robot client"""
        client_id = context.peer()
        logger.debug(f"Receiving policy instructions from {client_id}")

        policy_specs = pickle.loads(request.data)  # nosec
        assert isinstance(policy_specs, TinyPolicyConfig), (
            f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}"
        )

        logger.info(
            f"Policy type: {policy_specs.policy_type} | "
            f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
            f"Device: {policy_specs.device}"
        )

        assert policy_specs.policy_type in supported_policies, (
            f"Policy type {policy_specs.policy_type} not supported. "
            f"Supported policies: {supported_policies}"
        )

        self.device = policy_specs.device

        policy_class = get_policy_class(policy_specs.policy_type)

        start = time.time()
        self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
        self.policy.to(self.device)
        end = time.time()

        logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")

        return async_inference_pb2.Empty()

    def SendObservations(self, request_iterator, context):  # noqa: N802
        """Receive observations from the robot client"""
        client_id = context.peer()
        logger.debug(f"Receiving observations from {client_id}")

        for observation in request_iterator:
            receive_time = time.time()
            timed_observation = pickle.loads(observation.data)  # nosec
            deserialize_time = time.time()

            # If the queue is full, pop the old observation to make room
            if self.observation_queue.full():
                _ = self.observation_queue.get_nowait()
                logger.debug("Observation queue was full, removed oldest observation")

            # Now put the new observation (never blocks, as the queue has room here)
            self.observation_queue.put(timed_observation)
            queue_time = time.time()

            obs_timestep = timed_observation.get_timestep()
            obs_timestamp = timed_observation.get_timestamp()

            logger.info(
                f"Received observation #{obs_timestep} | "
                f"Client timestamp: {obs_timestamp:.6f} | "
                f"Server timestamp: {receive_time:.6f} | "
                f"Network latency: {receive_time - obs_timestamp:.6f}s | "
                f"Deserialization time: {deserialize_time - receive_time:.6f}s | "
                f"Queue time: {queue_time - deserialize_time:.6f}s"
            )

        return async_inference_pb2.Empty()

    def StreamActions(self, request, context):  # noqa: N802
        """Stream actions to the robot client"""
        client_id = context.peer()
        logger.debug(f"Client {client_id} connected for action streaming")

        # Generate an action chunk based on the most recent observation and its timestep
        start_time = time.time()
        try:
            obs = self.observation_queue.get()
            get_time = time.time()
            logger.info(
                f"Running inference for observation #{obs.get_timestep()} | "
                f"Queue get time: {get_time - start_time:.6f}s"
            )

            if obs:
                action = self._predict_action_chunk(obs)
                inference_end_time = time.time()
                logger.info(
                    f"Action chunk #{obs.get_timestep()} generated | "
                    f"Total inference time: {inference_end_time - get_time:.6f}s"
                )

                yield action
                yield_time = time.time()
                logger.info(
                    f"Action chunk #{obs.get_timestep()} sent | "
                    f"Send time: {yield_time - inference_end_time:.6f}s"
                )
            else:
                logger.warning("No observation in queue yet!")
                time.sleep(idle_wait)
        except Exception as e:
            logger.error(f"Error in StreamActions: {e}")

        return async_inference_pb2.Empty()

    def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
        """Turn a chunk of actions into a list of TimedAction instances, with
        the first action corresponding to t_0 and the rest corresponding to
        t_0 + i * environment_dt for i in range(len(action_chunk))
        """
        return [
            TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
        ]
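
    # Worked example of the timestamping above, assuming environment_dt = 1 / 30
    # (a 30 Hz control loop; the actual value is imported from robot_client):
    # a 3-action chunk timed with t_0 = 10.0 and i_0 = 5 yields TimedActions
    # stamped (10.000, timestep 5), (~10.033, timestep 6), (~10.067, timestep 7).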

    @torch.no_grad()
    def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        """Get an action chunk from the policy.

        NOTE: This temporary function only works for ACT policies (Pi0-like models are *not* supported just yet).
        """
        start_time = time.time()

        # prepare observation for the policy forward pass
        batch = self.policy.normalize_inputs(observation)
        normalize_time = time.time()
        logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s")

        if self.policy.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = [batch[key] for key in self.policy.config.image_features]
            prep_time = time.time()
            logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s")

        # the forward pass outputs policy.config.n_action_steps actions, which need not
        # match actions_per_chunk, so truncate to the first actions_per_chunk steps
        forward_start = time.time()
        actions = self.policy.model(batch)[0][:, : self.actions_per_chunk]
        forward_end = time.time()
        logger.debug(f"Policy forward pass time: {forward_end - forward_start:.6f}s")

        actions = self.policy.unnormalize_outputs({"action": actions})["action"]
        unnormalize_end = time.time()
        logger.debug(f"Action unnormalization time: {unnormalize_end - forward_end:.6f}s")

        end_time = time.time()
        logger.info(f"Action chunk generation total time: {end_time - start_time:.6f}s")

        return actions

    def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
        """Predict an action chunk based on the most recent observation"""
        start_time = time.time()

        # add a batch dimension and move tensors to the policy device;
        # images arrive as HWC and are converted to CHW
        observation = {}
        for k, v in observation_t.get_observation().items():
            if "image" in k:
                observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device)
            else:
                observation[k] = v.unsqueeze(0).to(self.device)

        prep_time = time.time()
        logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s")

        # NOTE: input normalization happens inside _get_action_chunk, so the raw
        # observation is passed through here (normalizing twice would corrupt the inputs)
        action_tensor = self._get_action_chunk(observation)

        # Remove batch dimension
        action_tensor = action_tensor.squeeze(0)

        # Move to CPU before serializing
        action_tensor = action_tensor.cpu()
        post_inference_time = time.time()
        logger.debug(f"Inference and CPU transfer time: {post_inference_time - prep_time:.6f}s")

        if action_tensor.dim() == 1:
            # No chunk dimension, so repeat the action to create a (dummy) chunk of actions
            action_tensor = action_tensor.repeat(self.actions_per_chunk, 1)

        action_chunk = self._time_action_chunk(
            observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
        )
        chunk_time = time.time()
        logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s")

        action_bytes = pickle.dumps(action_chunk)  # nosec
        serialize_time = time.time()
        logger.debug(f"Action serialization time: {serialize_time - chunk_time:.6f}s")

        # Create and return the Action message
        action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes)
        end_time = time.time()
        logger.info(
            f"Total action prediction time: {end_time - start_time:.6f}s | "
            f"Observation #{observation_t.get_timestep()} | "
            f"Action chunk size: {len(action_chunk)}"
        )

        time.sleep(inference_latency)

        return action

    def _stream_action_chunks_from_dataset(self) -> Generator[List[torch.Tensor], None, None]:
        """Stream chunks of actions from a prerecorded dataset.

        Returns:
            Generator that yields chunks of actions from the dataset
        """
        dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")

        # 1. Select the action column only; each entry is a 6-element action tensor
        actions = dataset["action"]
        action_indices = torch.arange(len(actions))

        # 2. Slice the indices into overlapping windows: each window spans
        # actions_per_chunk timesteps, and consecutive windows start
        # (actions_per_chunk - actions_overlap) timesteps apart
        indices_chunks = action_indices.unfold(
            0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
        )

        for idx_chunk in indices_chunks:
            yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
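
    # Worked example of the unfold call above, using the defaults from __init__
    # (actions_per_chunk=20, actions_overlap=10): on a 50-action dataset,
    # torch.arange(50).unfold(0, 20, 10) produces windows starting at indices
    # 0, 10, 20 and 30, so consecutive chunks share 10 actions. The overlap lets
    # the client blend into a fresh chunk before the current one runs out.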

    def _read_action_chunk(self, observation: Optional[TimedObservation] = None):
        """Dummy function for predicting an action chunk given an observation.

        Instead of computing actions on-the-fly, this method streams actions from a prerecorded dataset.
        """
        import warnings

        warnings.warn(
            "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
        )

        if not observation:
            observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
            transfer_state = 0
        else:
            transfer_state = observation.transfer_state

        # Get a chunk of actions from the generator
        actions_chunk = next(self.action_generator)

        # Return a list of TimedActions, with timestamps starting from the observation timestamp
        action_data = self._time_action_chunk(
            observation.get_timestamp(), actions_chunk, observation.get_timestep()
        )
        action_bytes = pickle.dumps(action_data)  # nosec

        # Create and return the Action message
        action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes)
        time.sleep(inference_latency)  # slow action generation, emulates inference time

        return action

    def stop(self):
        """Stop the server"""
        self.running = False
        logger.info("Server stopping...")


def serve():
    PORT = 8080

    # Create the server instance first
    policy_server = PolicyServer()

    # Setup and start the gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    server.add_insecure_port(f"[::]:{PORT}")
    server.start()
    logger.info(f"PolicyServer started on port {PORT}")
    print(f"PolicyServer started on port {PORT}")

    try:
        # Use the running attribute to control server lifetime
        while policy_server.running:
            time.sleep(1)  # Check every second instead of sleeping indefinitely
    except KeyboardInterrupt:
        policy_server.stop()
        logger.info("Keyboard interrupt received")
    finally:
        server.stop(0)
        logger.info("Server stopped")


if __name__ == "__main__":
    serve()
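
# A minimal sketch of a client exercising this server (illustrative only). It
# assumes the stub and message names generated from async_inference.proto follow
# the naming of the servicer used above (AsyncInferenceStub, Empty), that
# StreamActions takes an Empty request, and that SendPolicyInstructions and
# SendObservations have already set up the policy and queued an observation;
# the target address matches the port set in serve():
#
#   import pickle
#   import grpc
#   import async_inference_pb2
#   import async_inference_pb2_grpc
#
#   channel = grpc.insecure_channel("localhost:8080")
#   stub = async_inference_pb2_grpc.AsyncInferenceStub(channel)
#   stub.Ready(async_inference_pb2.Empty())
#   for action_msg in stub.StreamActions(async_inference_pb2.Empty()):
#       timed_actions = pickle.loads(action_msg.data)  # list[TimedAction]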