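"""Robot client for asynchronous inference over gRPC.

The client connects to a remote policy server, streams robot observations to
it, and executes the action chunks the server streams back. Three daemon
threads (observation sender, action receiver, action executor) start together
on a barrier and share access to the robot through a lock.
"""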

import logging
import logging.handlers
import os
import pickle  # nosec
import threading
import time
from queue import Empty, Queue
from typing import Any, Optional

import async_inference_pb2  # type: ignore
import async_inference_pb2_grpc  # type: ignore
import grpc
import torch
from lerobot.common.robot_devices.robots.utils import make_robot

# Create logs directory if it doesn't exist
os.makedirs("logs", exist_ok=True)

# Set up logging with both console and file output
logger = logging.getLogger("robot_client")
logger.setLevel(logging.INFO)

# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(
    logging.Formatter("%(asctime)s [CLIENT] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
)
logger.addHandler(console_handler)

# File handler - creates a new log file for each run
file_handler = logging.handlers.RotatingFileHandler(
    f"logs/robot_client_{int(time.time())}.log",
    maxBytes=10 * 1024 * 1024,  # 10MB
    backupCount=5,
)
file_handler.setFormatter(
    logging.Formatter("%(asctime)s [CLIENT] [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
)
logger.addHandler(file_handler)
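
# Control-loop timing: the client targets a ~30 Hz environment step and backs
# off by `idle_wait` seconds when idle or after an error.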
environment_dt = 1 / 30
idle_wait = 0.1


class TimedData:
    def __init__(self, timestamp: float, data: Any, timestep: int):
        """Initialize a TimedData object.

        Args:
            timestamp: Unix timestamp recorded when the data was created.
            data: The actual data to wrap a timestamp around.
            timestep: Integer index of the data in the stream.
        """
        self.timestamp = timestamp
        self.data = data
        self.timestep = timestep

    def get_data(self):
        return self.data

    def get_timestamp(self):
        return self.timestamp

    def get_timestep(self):
        return self.timestep


class TimedAction(TimedData):
    def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
        super().__init__(timestamp=timestamp, data=action, timestep=timestep)

    def get_action(self):
        return self.get_data()


class TimedObservation(TimedData):
    def __init__(
        self, timestamp: float, observation: dict[str, torch.Tensor], timestep: int, transfer_state: int = 0
    ):
        super().__init__(timestamp=timestamp, data=observation, timestep=timestep)
        self.transfer_state = transfer_state

    def get_observation(self):
        return self.get_data()
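

# NOTE: TinyPolicyConfig instances are pickled and sent to the policy server
# during the handshake (see RobotClient.start), so the server needs an
# importable definition of this class in order to unpickle them.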
class TinyPolicyConfig:
    def __init__(
        self,
        policy_type: str = "act",
        pretrained_name_or_path: str = "fracapuano/act_so100_test",
        device: str = "cpu",
    ):
        self.policy_type = policy_type
        self.pretrained_name_or_path = pretrained_name_or_path
        self.device = device


class RobotClient:
    """Client that streams observations to a policy server and executes the
    action chunks streamed back, using three synchronized worker threads."""

    def __init__(
        self,
        server_address: Optional[str] = None,
        policy_type: str = "act",  # "pi0"
        pretrained_name_or_path: str = "fracapuano/act_so100_test",  # "lerobot/pi0"
        policy_device: str = "mps",
    ):
        # Use environment variable if server_address is not provided
        if server_address is None:
            server_address = os.getenv("SERVER_ADDRESS", "localhost:8080")
            logger.info(f"No server address provided, using default address: {server_address}")

        self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device)
        self.channel = grpc.insecure_channel(server_address)
        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
        logger.info(f"Initializing client to connect to server at {server_address}")

        self.running = False
        self.first_observation_sent = False
        self.latest_action = 0  # timestamp of the most recently executed action
        self.action_chunk_size = 20

        self.action_queue = Queue()
        # 3 threads: observation sender, action receiver, action executor
        self.start_barrier = threading.Barrier(3)

        # Create a lock for robot access
        self.robot_lock = threading.Lock()

        # Stats for logging
        self.obs_sent_count = 0
        self.actions_received_count = 0
        self.actions_executed_count = 0
        self.last_obs_sent_time = 0
        self.last_action_received_time = 0

        start_time = time.time()
        self.robot = make_robot("so100")
        self.robot.connect()
        connect_time = time.time()
        logger.info(f"Robot connection time: {connect_time - start_time:.4f}s")

        time.sleep(idle_wait)  # sleep waiting for cameras to activate
        logger.info("Robot connected and ready")

    def timestamps(self):
        """Return the sorted timesteps of the actions currently in the queue."""
        return sorted([action.get_timestep() for action in self.action_queue.queue])

    def start(self):
        """Start the robot client and connect to the policy server"""
        try:
            # client-server handshake
            start_time = time.time()
            self.stub.Ready(async_inference_pb2.Empty())
            end_time = time.time()
            logger.info(f"Connected to policy server in {end_time - start_time:.4f}s")

            # send policy instructions
            policy_config_bytes = pickle.dumps(self.policy_config)
            policy_setup = async_inference_pb2.PolicySetup(
                transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes
            )

            logger.info("Sending policy instructions to policy server")
            logger.info(
                f"Policy type: {self.policy_config.policy_type} | "
                f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
                f"Device: {self.policy_config.device}"
            )

            self.stub.SendPolicyInstructions(policy_setup)
            self.running = True
            return True

        except grpc.RpcError as e:
            logger.error(f"Failed to connect to policy server: {e}")
            return False

    def stop(self):
        """Stop the robot client"""
        self.running = False

        self.robot.disconnect()
        logger.info("Robot disconnected")

        self.channel.close()
        logger.info("Client stopped, channel closed")

        # Log final stats
        logger.info(
            f"Session stats - Observations sent: {self.obs_sent_count}, "
            f"Action chunks received: {self.actions_received_count}, "
            f"Actions executed: {self.actions_executed_count}"
        )

    def send_observation(
        self,
        obs: TimedObservation,
        transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
    ) -> bool:
        """Send observation to the policy server.

        Returns True if the observation was sent successfully, False otherwise.
        """
        if not self.running:
            logger.warning("Client not running")
            return False

        assert isinstance(obs, TimedObservation), "Input observation needs to be a TimedObservation!"

        start_time = time.time()
        observation_bytes = pickle.dumps(obs)
        serialize_time = time.time()
        logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s")

        observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)

        try:
            send_start = time.time()
            _ = self.stub.SendObservations(iter([observation]))
            send_end = time.time()

            self.obs_sent_count += 1
            obs_timestep = obs.get_timestep()
            logger.info(
                f"Sent observation #{obs_timestep} | "
                f"Serialize time: {serialize_time - start_time:.6f}s | "
                f"Network time: {send_end - send_start:.6f}s | "
                f"Total time: {send_end - start_time:.6f}s"
            )

            if transfer_state == async_inference_pb2.TRANSFER_BEGIN:
                self.first_observation_sent = True

            self.last_obs_sent_time = send_end
            return True

        except grpc.RpcError as e:
            logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
            return False

    def _validate_action(self, action: TimedAction):
        """Received actions are kept only if they were produced for now or later, never for the past."""
        return not action.get_timestamp() < self.latest_action

    def _validate_action_chunk(self, actions: list[TimedAction]):
        assert len(actions) == self.action_chunk_size, (
            f"Action batch size must match action chunk size: {len(actions)} != {self.action_chunk_size}"
        )
        assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"

        return True

    def _inspect_action_queue(self):
        queue_size = self.action_queue.qsize()
        timesteps = sorted([action.get_timestep() for action in self.action_queue.queue])
        logger.debug(f"Queue size: {queue_size}, Queue contents: {timesteps}")

        return queue_size, timesteps

    def _clear_queue(self):
        """Clear the existing queue"""
        start_time = time.time()
        old_size = self.action_queue.qsize()

        while not self.action_queue.empty():
            try:
                self.action_queue.get_nowait()
            except Empty:
                break

        end_time = time.time()
        logger.debug(f"Queue cleared: {old_size} items removed in {end_time - start_time:.6f}s")

    def _fill_action_queue(self, actions: list[TimedAction]):
        """Fill the action queue with incoming valid actions"""
        start_time = time.time()
        valid_count = 0

        for action in actions:
            if self._validate_action(action):
                self.action_queue.put(action)
                valid_count += 1

        end_time = time.time()
        logger.debug(
            f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s"
        )

    def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
        """Clear the existing queue and fill it with new actions.

        This is a higher-level function that combines clearing and filling operations.

        Args:
            actions: List of TimedAction instances to queue.
        """
        start_time = time.time()
        logger.info(f"Current latest action: {self.latest_action}")

        # Get queue state before changes
        old_size, old_timesteps = self._inspect_action_queue()

        # Log incoming actions
        incoming_timesteps = [a.get_timestep() for a in actions]
        logger.info(f"Incoming actions: {len(actions)} items with timesteps {incoming_timesteps}")

        # Clear and fill
        clear_start = time.time()
        self._clear_queue()
        clear_end = time.time()

        fill_start = time.time()
        self._fill_action_queue(actions)
        fill_end = time.time()

        # Get queue state after changes
        new_size, new_timesteps = self._inspect_action_queue()
        end_time = time.time()

        logger.info(
            f"Queue update complete | "
            f"Before: {old_size} items | "
            f"After: {new_size} items | "
            f"Previous content: {old_timesteps} | "
            f"Incoming content: {incoming_timesteps} | "
            f"Current contents: {new_timesteps}"
        )
        logger.info(
            f"Clear time: {clear_end - clear_start:.6f}s | "
            f"Fill time: {fill_end - fill_start:.6f}s | "
            f"Total time: {end_time - start_time:.6f}s"
        )
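
    # Every chunk received from the server replaces whatever is left in the
    # queue: the queue is cleared and then refilled with the incoming actions
    # that pass `_validate_action`.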
    def receive_actions(self):
        """Receive actions from the policy server"""
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        logger.info("Action receiving thread starting")

        while self.running:
            try:
                # Use StreamActions to get a stream of actions from the server
                for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
                    receive_time = time.time()

                    # Deserialize bytes back into list[TimedAction]
                    deserialize_start = time.time()
                    timed_actions = pickle.loads(actions_chunk.data)  # nosec
                    deserialize_end = time.time()

                    # Calculate network latency if we have matching observations
                    if len(timed_actions) > 0:
                        first_action_timestep = timed_actions[0].get_timestep()
                        server_to_client_latency = receive_time - self.last_obs_sent_time
                        logger.info(
                            f"Received action chunk for step #{first_action_timestep} | "
                            f"Network latency (server->client): {server_to_client_latency:.6f}s | "
                            f"Deserialization time: {deserialize_end - deserialize_start:.6f}s"
                        )

                    # Update action queue
                    self._clear_and_fill_action_queue(timed_actions)
                    update_end = time.time()

                    self.actions_received_count += 1
                    self.last_action_received_time = receive_time

                    logger.info(
                        f"Action chunk processed | "
                        f"Total processing time: {update_end - receive_time:.6f}s | "
                        f"Round-trip time since observation sent: {receive_time - self.last_obs_sent_time:.6f}s"
                    )

            except grpc.RpcError as e:
                logger.error(f"Error receiving actions: {e}")
                time.sleep(idle_wait)  # Avoid tight loop on error

    def _get_next_action(self) -> Optional[TimedAction]:
        """Get the next action from the queue"""
        try:
            action = self.action_queue.get_nowait()
            return action

        except Empty:
            return None
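
    # Each execution cycle sleeps for `environment_dt`, pops a single action
    # from the queue, and sends it to the robot while holding `robot_lock`.
    # The executed action's timestamp becomes `latest_action`, which
    # `_validate_action` uses to drop stale actions from later chunks.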
    def execute_actions(self):
        """Continuously execute actions from the queue"""
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        logger.info("Action execution thread starting")

        while self.running:
            # Get the next action from the queue
            cycle_start = time.time()
            time.sleep(environment_dt)

            get_start = time.time()
            timed_action = self._get_next_action()
            get_end = time.time()

            if timed_action is not None:
                self.latest_action = timed_action.get_timestamp()
                action_timestep = timed_action.get_timestep()

                # Send the action to the robot - acquire lock before accessing it
                lock_start = time.time()
                if self.robot_lock.acquire(timeout=1.0):  # Wait up to 1 second to acquire the lock
                    lock_acquired = time.time()
                    try:
                        send_start = time.time()
                        self.robot.send_action(timed_action.get_action())
                        send_end = time.time()

                        self.actions_executed_count += 1
                        logger.info(
                            f"Executed action #{action_timestep} | "
                            f"Queue get time: {get_end - get_start:.6f}s | "
                            f"Lock wait time: {lock_acquired - lock_start:.6f}s | "
                            f"Action send time: {send_end - send_start:.6f}s | "
                            f"Total execution time: {send_end - cycle_start:.6f}s | "
                            f"Action latency: {send_end - timed_action.get_timestamp():.6f}s | "
                            f"Queue size: {self.action_queue.qsize()}"
                        )
                    finally:
                        # Always release the lock in a finally block to ensure it's released
                        self.robot_lock.release()
                else:
                    logger.warning("Could not acquire robot lock for action execution, retrying next cycle")

            else:
                if get_end - get_start > 0.001:  # Only log if there was a measurable delay
                    logger.debug(f"No action available, get time: {get_end - get_start:.6f}s")
                time.sleep(idle_wait)

    def stream_observations(self, get_observation_fn):
        """Continuously stream observations to the server"""
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        logger.info("Observation streaming thread starting")

        first_observation = True
        while self.running:
            try:
                # Get the latest observation from the provided function
                cycle_start = time.time()
                time.sleep(environment_dt)

                get_start = time.time()
                observation = get_observation_fn()
                get_end = time.time()

                # Skip if observation is None (couldn't acquire lock)
                if observation is None:
                    logger.warning("Failed to get observation, skipping cycle")
                    continue

                # Set appropriate transfer state
                if first_observation:
                    state = async_inference_pb2.TRANSFER_BEGIN
                    first_observation = False
                else:
                    state = async_inference_pb2.TRANSFER_MIDDLE

                obs_timestep = observation.get_timestep()
                logger.debug(f"Got observation #{obs_timestep} in {get_end - get_start:.6f}s, sending...")

                send_start = time.time()
                self.send_observation(observation, state)
                send_end = time.time()

                logger.info(
                    f"Observation #{obs_timestep} cycle complete | "
                    f"Get time: {get_end - get_start:.6f}s | "
                    f"Send time: {send_end - send_start:.6f}s | "
                    f"Total cycle time: {send_end - cycle_start:.6f}s"
                )
            except Exception as e:
                logger.error(f"Error in observation sender: {e}")
                time.sleep(idle_wait)


def async_client():
    # Example of how to use the RobotClient
    client = RobotClient()

    if client.start():
        # Function that captures observations from the robot
        def get_observation():
            # Create a counter attribute if it doesn't exist
            if not hasattr(get_observation, "counter"):
                get_observation.counter = 0

            # Acquire lock before accessing the robot
            start_time = time.time()
            observation_content = None
            if client.robot_lock.acquire(timeout=1.0):  # Wait up to 1 second to acquire the lock
                lock_time = time.time()
                try:
                    capture_start = time.time()
                    observation_content = client.robot.capture_observation()
                    capture_end = time.time()

                    logger.debug(
                        f"Observation capture | "
                        f"Lock acquisition: {lock_time - start_time:.6f}s | "
                        f"Capture time: {capture_end - capture_start:.6f}s"
                    )
                finally:
                    # Always release the lock in a finally block to ensure it's released
                    client.robot_lock.release()
            else:
                logger.warning("Could not acquire robot lock for observation capture, skipping this cycle")
                return None  # Return None to indicate no observation was captured

            current_time = time.time()
            observation = TimedObservation(
                timestamp=current_time, observation=observation_content, timestep=get_observation.counter
            )

            # Increment counter for next call
            get_observation.counter += 1

            end_time = time.time()
            logger.debug(
                f"Observation #{observation.get_timestep()} prepared | "
                f"Total time: {end_time - start_time:.6f}s"
            )

            return observation

        logger.info("Starting all threads...")

        # Create and start observation sender thread
        obs_thread = threading.Thread(target=client.stream_observations, args=(get_observation,))
        obs_thread.daemon = True

        # Create and start action receiver thread
        action_receiver_thread = threading.Thread(target=client.receive_actions)
        action_receiver_thread.daemon = True

        # Create action execution thread
        action_execution_thread = threading.Thread(target=client.execute_actions)
        action_execution_thread.daemon = True

        # Start all threads
        obs_thread.start()
        action_receiver_thread.start()
        action_execution_thread.start()

        try:
            # Main thread just keeps everything alive
            while client.running:
                time.sleep(idle_wait)

        except KeyboardInterrupt:
            pass

        finally:
            client.stop()
            logger.info("Client stopped")
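

# Example invocation (filename assumed), pointing the client at a local policy server:
#   SERVER_ADDRESS=localhost:8080 python robot_client.py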
if __name__ == "__main__":
    async_client()