import itertools
import logging
import logging.handlers
import os
import pickle  # nosec
import time
import warnings
from concurrent import futures
from queue import Queue
from typing import Generator, Optional

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore
import grpc
import torch
from datasets import load_dataset

from lerobot.common.policies.factory import get_policy_class
from lerobot.scripts.server.robot_client import (
    TimedAction,
    TimedObservation,
    TinyPolicyConfig,
    environment_dt,
)
# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)

# Set up logging with both console and file output
logger = logging.getLogger("policy_server")
logger.setLevel(logging.INFO)

log_formatter = logging.Formatter(
    "%(asctime)s [SERVER] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)

# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)
logger.addHandler(console_handler)

# File handler - creates a new, timestamped log file for each run
file_handler = logging.handlers.RotatingFileHandler(
    f"logs/policy_server_{int(time.time())}.log",
    maxBytes=10 * 1024 * 1024,  # 10MB
    backupCount=5,
)
file_handler.setFormatter(log_formatter)
logger.addHandler(file_handler)
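
# Tunables (behavior observable in the handlers below):
# - inference_latency: artificial delay applied after producing each action
#   chunk, emulating a ~3 Hz inference loop
# - idle_wait: back-off used when no observation has been queued yet
# - supported_policies: policy types accepted by SendPolicyInstructions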
inference_latency = 1 / 3
idle_wait = 0.1
supported_policies = ["act"]


class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
    def __init__(self):
        self.actions_per_chunk = 20
        self.actions_overlap = 10

        # Initialize the dataset action generator. itertools.cycle is lazy, so
        # the chunk parameters above are only read once chunks are requested.
        self.action_generator = itertools.cycle(self._stream_action_chunks_from_dataset())

        self._setup_server()

        self.running = True  # controls the server lifetime (see serve())

    def _setup_server(self) -> None:
        """Flushes server state when a new client connects."""
        # only run inference on the latest observation received by the server
        self.observation_queue = Queue(maxsize=1)
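
    # Producer/consumer note: SendObservations is the only producer and, when
    # the queue is full, pops the stale item before putting the new one;
    # StreamActions is the only consumer and blocks on get() until an
    # observation arrives. The maxsize=1 queue therefore always holds the most
    # recent observation, if any.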

    def Ready(self, request, context):  # noqa: N802
        client_id = context.peer()
        logger.info(f"Client {client_id} connected and ready")
        self._setup_server()
        return async_inference_pb2.Empty()

    def SendPolicyInstructions(self, request, context):  # noqa: N802
        """Receive policy instructions from the robot client"""
        client_id = context.peer()
        logger.debug(f"Receiving policy instructions from {client_id}")

        policy_specs = pickle.loads(request.data)  # nosec
        assert isinstance(policy_specs, TinyPolicyConfig), (
            f"Policy specs must be a TinyPolicyConfig. Got {type(policy_specs)}"
        )

        logger.info(
            f"Policy type: {policy_specs.policy_type} | "
            f"Pretrained name or path: {policy_specs.pretrained_name_or_path} | "
            f"Device: {policy_specs.device}"
        )

        assert policy_specs.policy_type in supported_policies, (
            f"Policy type {policy_specs.policy_type} not supported. Supported policies: {supported_policies}"
        )

        self.device = policy_specs.device
        policy_class = get_policy_class(policy_specs.policy_type)

        start = time.time()
        self.policy = policy_class.from_pretrained(policy_specs.pretrained_name_or_path)
        self.policy.to(self.device)
        end = time.time()
        logger.info(f"Time taken to put policy on {self.device}: {end - start:.4f} seconds")

        return async_inference_pb2.Empty()
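
    # NOTE: policy instructions and observations arrive as pickled payloads
    # (hence the `# nosec` markers). pickle.loads can execute arbitrary code
    # during deserialization, so this server should only be reachable by
    # trusted clients on a trusted network.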

    def SendObservations(self, request_iterator, context):  # noqa: N802
        """Receive observations from the robot client"""
        client_id = context.peer()
        logger.debug(f"Receiving observations from {client_id}")

        for observation in request_iterator:
            receive_time = time.time()
            timed_observation = pickle.loads(observation.data)  # nosec
            deserialize_time = time.time()

            # If the queue is full, pop the old observation to make room
            if self.observation_queue.full():
                _ = self.observation_queue.get_nowait()
                logger.debug("Observation queue was full, removed oldest observation")

            # Now put the new observation (never blocks, as the queue is non-full here)
            self.observation_queue.put(timed_observation)
            queue_time = time.time()

            obs_timestep = timed_observation.get_timestep()
            obs_timestamp = timed_observation.get_timestamp()
            logger.info(
                f"Received observation #{obs_timestep} | "
                f"Client timestamp: {obs_timestamp:.6f} | "
                f"Server timestamp: {receive_time:.6f} | "
                f"Network latency: {receive_time - obs_timestamp:.6f}s | "
                f"Deserialization time: {deserialize_time - receive_time:.6f}s | "
                f"Queue time: {queue_time - deserialize_time:.6f}s"
            )

        return async_inference_pb2.Empty()
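
    # NOTE: the "Network latency" figure above subtracts a client-side
    # timestamp from a server-side one, so it is only meaningful when the two
    # machines' clocks are synchronized (or client and server share a host).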

    def StreamActions(self, request, context):  # noqa: N802
        """Stream actions to the robot client"""
        client_id = context.peer()
        logger.debug(f"Client {client_id} connected for action streaming")

        # Generate an action chunk based on the most recent observation and its timestep
        start_time = time.time()
        try:
            obs = self.observation_queue.get()
            get_time = time.time()
            logger.info(
                f"Running inference for observation #{obs.get_timestep()} | Queue get time: {get_time - start_time:.6f}s"
            )

            if obs:
                action = self._predict_action_chunk(obs)
                inference_end_time = time.time()
                logger.info(
                    f"Action chunk #{obs.get_timestep()} generated | "
                    f"Total inference time: {inference_end_time - get_time:.6f}s"
                )

                yield action
                yield_time = time.time()
                logger.info(
                    f"Action chunk #{obs.get_timestep()} sent | Send time: {yield_time - inference_end_time:.6f}s"
                )
            else:
                logger.warning("No observation in queue yet!")
                time.sleep(idle_wait)
        except Exception as e:
            logger.error(f"Error in StreamActions: {e}")

        return async_inference_pb2.Empty()
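
    # NOTE: StreamActions is a streaming handler, but each call yields at most
    # one action chunk before the stream closes, so the client presumably
    # re-issues StreamActions for every chunk it wants. Also note that
    # observation_queue.get() above blocks until an observation is available,
    # making the `else` branch reachable only if a falsy observation were
    # ever enqueued.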

    def _time_action_chunk(self, t_0: float, action_chunk: list[torch.Tensor], i_0: int) -> list[TimedAction]:
        """Turn a chunk of actions into a list of TimedAction instances,
        with the first action corresponding to t_0 and the rest corresponding to
        t_0 + i * environment_dt for i in range(len(action_chunk))
        """
        return [
            TimedAction(t_0 + i * environment_dt, action, i_0 + i) for i, action in enumerate(action_chunk)
        ]
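
    # Example: with t_0 = 12.0, i_0 = 300 and environment_dt = 0.05 (the dt
    # value is assumed purely for illustration; it is imported from
    # robot_client), a 3-action chunk becomes TimedActions stamped
    # (12.00, 300), (12.05, 301), (12.10, 302).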

    def _get_action_chunk(self, observation: dict[str, torch.Tensor]) -> torch.Tensor:
        """Get an action chunk from the policy.

        NOTE: This temporary method only works for ACT policies (Pi0-like
        models are *not* supported just yet).
        """
        start_time = time.time()

        # prepare observation for the policy forward pass
        batch = self.policy.normalize_inputs(observation)
        normalize_time = time.time()
        logger.debug(f"Observation normalization time: {normalize_time - start_time:.6f}s")

        if self.policy.config.image_features:
            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
            batch["observation.images"] = [batch[key] for key in self.policy.config.image_features]
            prep_time = time.time()
            logger.debug(f"Observation image preparation time: {prep_time - normalize_time:.6f}s")

        # the forward pass outputs up to policy.config.n_action_steps actions,
        # which may differ from actions_per_chunk, hence the slicing
        forward_start = time.time()
        actions = self.policy.model(batch)[0][:, : self.actions_per_chunk]
        forward_end = time.time()
        logger.debug(f"Policy forward pass time: {forward_end - forward_start:.6f}s")

        actions = self.policy.unnormalize_outputs({"action": actions})["action"]
        unnormalize_end = time.time()
        logger.debug(f"Action unnormalization time: {unnormalize_end - forward_end:.6f}s")

        end_time = time.time()
        logger.info(f"Action chunk generation total time: {end_time - start_time:.6f}s")

        return actions
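
    # Shape sketch (assuming ACT and a batch of one observation): the model
    # returns an action tensor of shape (1, n_action_steps, action_dim);
    # slicing with [:, :actions_per_chunk] keeps at most the first
    # actions_per_chunk steps, and the caller squeezes away the batch
    # dimension before serializing.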

    def _predict_action_chunk(self, observation_t: TimedObservation) -> list[TimedAction]:
        """Predict an action chunk based on an observation"""
        start_time = time.time()

        # prepare observation: add a batch dimension and move tensors to device
        observation = {}
        for k, v in observation_t.get_observation().items():
            if "image" in k:
                observation[k] = v.permute(2, 0, 1).unsqueeze(0).to(self.device)
            else:
                observation[k] = v.unsqueeze(0).to(self.device)
        prep_time = time.time()
        logger.debug(f"Observation preparation time: {prep_time - start_time:.6f}s")

        # run inference (normalization happens inside _get_action_chunk, so the
        # observation is passed in unnormalized) and remove the batch dimension
        action_tensor = self._get_action_chunk(observation)
        action_tensor = action_tensor.squeeze(0)

        # Move to CPU before serializing
        action_tensor = action_tensor.cpu()
        post_inference_time = time.time()
        logger.debug(f"Post-inference processing start: {post_inference_time - prep_time:.6f}s")

        if action_tensor.dim() == 1:
            # No chunk dimension, so repeat the action to create a (dummy) chunk of actions
            action_tensor = action_tensor.repeat(self.actions_per_chunk, 1)

        action_chunk = self._time_action_chunk(
            observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep()
        )
        chunk_time = time.time()
        logger.debug(f"Action chunk creation time: {chunk_time - post_inference_time:.6f}s")

        action_bytes = pickle.dumps(action_chunk)  # nosec
        serialize_time = time.time()
        logger.debug(f"Action serialization time: {serialize_time - chunk_time:.6f}s")

        # Create and return the Action message
        action = async_inference_pb2.Action(transfer_state=observation_t.transfer_state, data=action_bytes)
        end_time = time.time()
        logger.info(
            f"Total action prediction time: {end_time - start_time:.6f}s | "
            f"Observation #{observation_t.get_timestep()} | "
            f"Action chunk size: {len(action_chunk)}"
        )

        time.sleep(inference_latency)  # emulates a slower inference pipeline

        return action

    def _stream_action_chunks_from_dataset(self) -> Generator[torch.Tensor, None, None]:
        """Stream chunks of actions from a prerecorded dataset.

        Returns:
            Generator that yields chunks of actions from the dataset
        """
        dataset = load_dataset("fracapuano/so100_test", split="train").with_format("torch")

        # 1. Select the action column only (each element is a 6-dim action tensor)
        actions = dataset["action"]
        action_indices = torch.arange(len(actions))

        # 2. Slide a window of actions_per_chunk indices over the dataset,
        #    advancing by (actions_per_chunk - actions_overlap) indices each
        #    time so that consecutive chunks overlap
        indices_chunks = action_indices.unfold(
            0, self.actions_per_chunk, self.actions_per_chunk - self.actions_overlap
        )

        for idx_chunk in indices_chunks:
            yield actions[idx_chunk[0] : idx_chunk[-1] + 1, :]
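
    # Example: with actions_per_chunk=20 and actions_overlap=10, unfold slides
    # a window of 20 indices with a stride of 10, yielding index ranges
    # [0, 20), [10, 30), [20, 40), ... so each chunk shares its first 10
    # actions with the tail of the previous one.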

    def _read_action_chunk(self, observation: Optional[TimedObservation] = None):
        """Dummy function for predicting an action chunk given an observation.

        Instead of computing actions on-the-fly, this method streams
        actions from a prerecorded dataset.
        """
        warnings.warn(
            "This method is deprecated and will be removed in the future.", DeprecationWarning, stacklevel=2
        )

        if not observation:
            observation = TimedObservation(timestamp=time.time(), observation={}, timestep=0)
            transfer_state = 0
        else:
            transfer_state = observation.transfer_state

        # Get a chunk of actions from the generator
        actions_chunk = next(self.action_generator)

        # Return a list of TimedActions, with timestamps starting from the observation timestamp
        action_data = self._time_action_chunk(
            observation.get_timestamp(), actions_chunk, observation.get_timestep()
        )
        action_bytes = pickle.dumps(action_data)  # nosec

        # Create and return the Action message
        action = async_inference_pb2.Action(transfer_state=transfer_state, data=action_bytes)

        time.sleep(inference_latency)  # slow action generation, emulates inference time

        return action

    def stop(self):
        """Stop the server"""
        self.running = False
        logger.info("Server stopping...")


def serve():
    port = 8080

    # Create the servicer instance first
    policy_server = PolicyServer()

    # Set up and start the gRPC server
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    async_inference_pb2_grpc.add_AsyncInferenceServicer_to_server(policy_server, server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    logger.info(f"PolicyServer started on port {port}")

    try:
        # Use the running attribute to control the server lifetime
        while policy_server.running:
            time.sleep(1)  # check every second instead of sleeping indefinitely
    except KeyboardInterrupt:
        policy_server.stop()
        logger.info("Keyboard interrupt received")
    finally:
        server.stop(0)
        logger.info("Server stopped")


if __name__ == "__main__":
    serve()
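

# A minimal client-side sketch of the handshake. The stub name below is
# inferred from the generated async_inference_pb2_grpc module (the servicer is
# AsyncInferenceServicer, so the stub should be AsyncInferenceStub), and Ready
# is assumed to take an Empty request; in practice lerobot's RobotClient
# wraps this exchange:
#
#     channel = grpc.insecure_channel("localhost:8080")
#     stub = async_inference_pb2_grpc.AsyncInferenceStub(channel)
#     stub.Ready(async_inference_pb2.Empty())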