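# Repository page metadata captured together with this file (not part of the program):
#   Francesco Capuano
#   Commit 78df986: "fix: ports ports ports"
#   Page links: raw | history | blame | contribute | delete
#   File size: 21.4 kB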
import logging
import logging.handlers
import os
import pickle # nosec
import threading
import time
from queue import Empty, Queue
from typing import Any, Optional
import async_inference_pb2 # type: ignore
import async_inference_pb2_grpc # type: ignore
import grpc
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
# Log files live under ./logs; make sure the directory exists before the file handler opens its file.
os.makedirs("logs", exist_ok=True)

# The console and file handlers share one record layout, so it is spelled out once.
_LOG_FORMAT = "%(asctime)s [CLIENT] [%(levelname)s] %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

# Module-wide logger writing to both the console and a file
logger = logging.getLogger("robot_client")
logger.setLevel(logging.INFO)

# Console output
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT))
logger.addHandler(console_handler)

# File output: the file name embeds the start time, so every run writes to a new log file
file_handler = logging.handlers.RotatingFileHandler(
    f"logs/robot_client_{int(time.time())}.log",
    maxBytes=10 * 1024 * 1024,  # 10MB
    backupCount=5,
)
file_handler.setFormatter(logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT))
logger.addHandler(file_handler)

# Sleep applied at the top of every action-execution and observation-streaming cycle (seconds)
environment_dt = 1 / 30
# Sleep used when idle: after connecting, after an error, when no action is queued (seconds)
idle_wait = 0.1
class TimedData:
    """A payload bundled with the timestamp and the timestep it is associated with."""

    def __init__(self, timestamp: float, data: Any, timestep: int):
        """Store the payload together with its timing metadata.

        Args:
            timestamp: Unix timestamp relative to the data's creation.
            data: The payload to wrap.
            timestep: Integer timestep associated with the payload.
        """
        # Plain attributes, assigned in a fixed order (timestamp, data, timestep).
        self.timestamp, self.data, self.timestep = timestamp, data, timestep

    def get_data(self) -> Any:
        """Return the wrapped payload."""
        return self.data

    def get_timestamp(self) -> float:
        """Return the timestamp given at construction."""
        return self.timestamp

    def get_timestep(self) -> int:
        """Return the timestep given at construction."""
        return self.timestep
class TimedAction(TimedData):
    """TimedData whose payload is an action."""

    def __init__(self, timestamp: float, action: torch.Tensor, timestep: int):
        """Wrap ``action`` with its timestamp and timestep.

        Args:
            timestamp: Timestamp stored alongside the action.
            action: The action payload; it is stored as the ``data`` attribute.
            timestep: Integer timestep stored alongside the action.
        """
        super().__init__(timestamp, action, timestep)

    def get_action(self):
        """Return the wrapped action (same object as ``get_data()``)."""
        return self.get_data()
class TimedObservation(TimedData):
    """TimedData whose payload is an observation, plus a transfer-state marker."""

    def __init__(
        self,
        timestamp: float,
        observation: dict[str, torch.Tensor],
        timestep: int,
        transfer_state: int = 0,
    ):
        """Wrap ``observation`` with its timing metadata.

        Args:
            timestamp: Timestamp stored alongside the observation.
            observation: The observation payload; it is stored as the ``data`` attribute.
            timestep: Integer timestep stored alongside the observation.
            transfer_state: Integer marker kept on the instance; defaults to 0.
        """
        super().__init__(timestamp, observation, timestep)
        self.transfer_state = transfer_state

    def get_observation(self):
        """Return the wrapped observation (same object as ``get_data()``)."""
        return self.get_data()
class TinyPolicyConfig:
    """Small bag of policy settings; RobotClient pickles it and sends it to the policy server."""

    def __init__(
        self,
        policy_type: str = "act",
        pretrained_name_or_path: str = "fracapuano/act_so100_test",
        device: str = "cpu",
    ):
        """Record the policy settings as plain attributes.

        Args:
            policy_type: Policy type identifier.
            pretrained_name_or_path: Name or path of the pretrained policy.
            device: Device name.
        """
        # Plain attributes, assigned in a fixed order.
        self.policy_type, self.pretrained_name_or_path, self.device = (
            policy_type,
            pretrained_name_or_path,
            device,
        )
class RobotClient:
    """gRPC client that streams robot observations to a policy server and executes the returned actions.

    The long-running methods ``stream_observations``, ``receive_actions`` and
    ``execute_actions`` are each meant to run in their own thread: they rendezvous on
    ``start_barrier`` before entering their loops and leave them once ``running`` is False.
    Code that touches ``robot`` is expected to hold ``robot_lock``.
    """

    def __init__(
        self,
        server_address: Optional[str] = None,
        policy_type: str = "act",  # "pi0"
        pretrained_name_or_path: str = "fracapuano/act_so100_test",  # "lerobot/pi0"
        policy_device: str = "mps",
    ):
        """Create the gRPC channel and stub, then connect to the robot.

        Args:
            server_address: Address of the policy server. When None, the ``SERVER_ADDRESS``
                environment variable is used, falling back to ``localhost:8080``.
            policy_type: Policy type forwarded to the server in the policy configuration.
            pretrained_name_or_path: Pretrained policy name or path forwarded to the server.
            policy_device: Device name forwarded to the server.
        """
        # Use environment variable if server_address is not provided
        if server_address is None:
            server_address = os.getenv("SERVER_ADDRESS", "localhost:8080")
            logger.info(f"No server address provided, using default address: {server_address}")

        self.policy_config = TinyPolicyConfig(policy_type, pretrained_name_or_path, policy_device)
        self.channel = grpc.insecure_channel(server_address)
        self.stub = async_inference_pb2_grpc.AsyncInferenceStub(self.channel)
        logger.info(f"Initializing client to connect to server at {server_address}")

        self.running = False
        self.first_observation_sent = False
        # Timestamp of the most recently dequeued action (see execute_actions); 0 until then.
        self.latest_action = 0
        self.action_chunk_size = 20

        self.action_queue = Queue()
        self.start_barrier = threading.Barrier(
            3
        )  # 3 threads: observation sender, action receiver, action executor

        # Create a lock for robot access
        self.robot_lock = threading.Lock()

        # Stats for logging
        self.obs_sent_count = 0
        self.actions_received_count = 0
        self.actions_executed_count = 0
        self.last_obs_sent_time = 0
        self.last_action_received_time = 0

        start_time = time.time()
        self.robot = make_robot("so100")
        self.robot.connect()
        connect_time = time.time()
        logger.info(f"Robot connection time: {connect_time - start_time:.4f}s")

        time.sleep(idle_wait)  # sleep waiting for cameras to activate
        logger.info("Robot connected and ready")

    def _queued_timesteps(self) -> list[int]:
        """Return the sorted timesteps of the actions currently sitting in the queue."""
        # Copy the underlying deque while holding the queue's own mutex. Iterating it
        # directly can raise "deque mutated during iteration" when another thread puts
        # or gets concurrently.
        with self.action_queue.mutex:
            queued = list(self.action_queue.queue)
        return sorted(action.get_timestep() for action in queued)

    def timestamps(self):
        """Get the sorted timesteps of the actions in the queue"""
        return self._queued_timesteps()

    def start(self):
        """Start the robot client and connect to the policy server.

        Performs the ``Ready`` handshake, then sends the pickled policy configuration.

        Returns:
            True when both calls succeeded (``running`` is then set), False when a
            ``grpc.RpcError`` was raised.
        """
        try:
            # client-server handshake
            start_time = time.time()
            self.stub.Ready(async_inference_pb2.Empty())
            end_time = time.time()
            logger.info(f"Connected to policy server in {end_time - start_time:.4f}s")

            # send policy instructions
            policy_config_bytes = pickle.dumps(self.policy_config)
            policy_setup = async_inference_pb2.PolicySetup(
                transfer_state=async_inference_pb2.TRANSFER_BEGIN, data=policy_config_bytes
            )

            logger.info("Sending policy instructions to policy server")
            logger.info(
                f"Policy type: {self.policy_config.policy_type} | "
                f"Pretrained name or path: {self.policy_config.pretrained_name_or_path} | "
                f"Device: {self.policy_config.device}"
            )

            self.stub.SendPolicyInstructions(policy_setup)

            self.running = True
            return True

        except grpc.RpcError as e:
            logger.error(f"Failed to connect to policy server: {e}")
            return False

    def stop(self):
        """Stop the robot client: disconnect the robot, close the channel, log session stats."""
        self.running = False

        # Give an in-flight robot access (action send / observation capture) a bounded
        # chance to finish before disconnecting. Proceed anyway after the timeout.
        lock_acquired = self.robot_lock.acquire(timeout=1.0)
        try:
            self.robot.disconnect()
            logger.info("Robot disconnected")
        finally:
            if lock_acquired:
                self.robot_lock.release()

            # Close the channel even when disconnecting the robot failed.
            self.channel.close()
            logger.info("Client stopped, channel closed")

            # Log final stats
            logger.info(
                f"Session stats - Observations sent: {self.obs_sent_count}, "
                f"Action chunks received: {self.actions_received_count}, "
                f"Actions executed: {self.actions_executed_count}"
            )

    def send_observation(
        self,
        obs: TimedObservation,
        transfer_state: async_inference_pb2.TransferState = async_inference_pb2.TRANSFER_MIDDLE,
    ) -> bool:
        """Send observation to the policy server.

        Args:
            obs: The observation to pickle and send.
            transfer_state: Transfer-state marker attached to the outgoing message.

        Returns:
            True if the observation was sent successfully, False when the client is not
            running or the RPC failed.

        Raises:
            TypeError: If ``obs`` is not a TimedObservation.
        """
        if not self.running:
            logger.warning("Client not running")
            return False

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(obs, TimedObservation):
            raise TypeError("Input observation needs to be a TimedObservation!")

        start_time = time.time()
        observation_bytes = pickle.dumps(obs)
        serialize_time = time.time()
        logger.debug(f"Observation serialization time: {serialize_time - start_time:.6f}s")

        observation = async_inference_pb2.Observation(transfer_state=transfer_state, data=observation_bytes)

        try:
            send_start = time.time()
            _ = self.stub.SendObservations(iter([observation]))
            send_end = time.time()

            self.obs_sent_count += 1
            obs_timestep = obs.get_timestep()

            logger.info(
                f"Sent observation #{obs_timestep} | "
                f"Serialize time: {serialize_time - start_time:.6f}s | "
                f"Network time: {send_end - send_start:.6f}s | "
                f"Total time: {send_end - start_time:.6f}s"
            )

            if transfer_state == async_inference_pb2.TRANSFER_BEGIN:
                self.first_observation_sent = True

            self.last_obs_sent_time = send_end
            return True

        except grpc.RpcError as e:
            logger.error(f"Error sending observation #{obs.get_timestep()}: {e}")
            return False

    def _validate_action(self, action: TimedAction):
        """Received actions are kept only when they have been produced for now or later, never before"""
        return not action.get_timestamp() < self.latest_action

    def _validate_action_chunk(self, actions: list[TimedAction]):
        """Assert that ``actions`` has ``action_chunk_size`` items and that each one is valid.

        Returns True when both assertions hold; raises AssertionError otherwise.
        """
        assert len(actions) == self.action_chunk_size, (
            f"Action batch size must match action chunk size: {len(actions)} != {self.action_chunk_size}"
        )
        assert all(self._validate_action(action) for action in actions), "Invalid action in chunk"

        return True

    def _inspect_action_queue(self):
        """Return ``(queue_size, sorted_timesteps)`` for the queue's current content."""
        # Size and timesteps come from the same snapshot, so they always agree.
        timestamps = self._queued_timesteps()
        queue_size = len(timestamps)
        logger.debug(f"Queue size: {queue_size}, Queue contents: {timestamps}")
        return queue_size, timestamps

    def _clear_queue(self):
        """Clear the existing queue"""
        start_time = time.time()
        old_size = self.action_queue.qsize()

        while not self.action_queue.empty():
            try:
                self.action_queue.get_nowait()
            except Empty:
                # Another thread drained the last item between empty() and get_nowait().
                break

        end_time = time.time()
        logger.debug(f"Queue cleared: {old_size} items removed in {end_time - start_time:.6f}s")

    def _fill_action_queue(self, actions: list[TimedAction]):
        """Fill the action queue with incoming valid actions"""
        start_time = time.time()
        valid_count = 0

        for action in actions:
            if self._validate_action(action):
                self.action_queue.put(action)
                valid_count += 1

        end_time = time.time()
        logger.debug(
            f"Queue filled: {valid_count}/{len(actions)} valid actions added in {end_time - start_time:.6f}s"
        )

    def _clear_and_fill_action_queue(self, actions: list[TimedAction]):
        """Clear the existing queue and fill it with new actions.

        This is a higher-level function that combines clearing and filling operations.
        The two steps are not atomic with respect to the thread consuming the queue.

        Args:
            actions: List of TimedAction instances to queue
        """
        start_time = time.time()
        logger.info(f"Current latest action: {self.latest_action}")

        # Get queue state before changes
        old_size, old_timesteps = self._inspect_action_queue()

        # Log incoming actions
        incoming_timesteps = [a.get_timestep() for a in actions]
        logger.info(f"Incoming actions: {len(actions)} items with timesteps {incoming_timesteps}")

        # Clear and fill
        clear_start = time.time()
        self._clear_queue()
        clear_end = time.time()

        fill_start = time.time()
        self._fill_action_queue(actions)
        fill_end = time.time()

        # Get queue state after changes
        new_size, new_timesteps = self._inspect_action_queue()

        end_time = time.time()
        logger.info(
            f"Queue update complete | "
            f"Before: {old_size} items | "
            f"After: {new_size} items | "
            f"Previous content: {old_timesteps} | "
            f"Incoming content: {incoming_timesteps} | "
            f"Current contents: {new_timesteps}"
        )
        logger.info(
            f"Clear time: {clear_end - clear_start:.6f}s | "
            f"Fill time: {fill_end - fill_start:.6f}s | "
            f"Total time: {end_time - start_time:.6f}s"
        )

    def receive_actions(self):
        """Receive actions from the policy server.

        Thread body: waits on the start barrier, then re-opens ``StreamActions`` for as
        long as ``running`` is True, replacing the queue content with every chunk received.
        """
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        logger.info("Action receiving thread starting")

        while self.running:
            try:
                # Use StreamActions to get a stream of actions from the server
                for actions_chunk in self.stub.StreamActions(async_inference_pb2.Empty()):
                    receive_time = time.time()

                    # Deserialize bytes back into list[TimedAction]
                    # SECURITY: unpickling can execute arbitrary code carried by the payload;
                    # only connect this client to a trusted policy server.
                    deserialize_start = time.time()
                    timed_actions = pickle.loads(actions_chunk.data)  # nosec
                    deserialize_end = time.time()

                    # Log timing relative to the last observation sent
                    if timed_actions:
                        first_action_timestep = timed_actions[0].get_timestep()
                        server_to_client_latency = receive_time - self.last_obs_sent_time

                        logger.info(
                            f"Received action chunk for step #{first_action_timestep} | "
                            f"Network latency (server->client): {server_to_client_latency:.6f}s | "
                            f"Deserialization time: {deserialize_end - deserialize_start:.6f}s"
                        )

                    # Update action queue
                    self._clear_and_fill_action_queue(timed_actions)
                    update_end = time.time()

                    self.actions_received_count += 1
                    self.last_action_received_time = receive_time

                    logger.info(
                        f"Action chunk processed | "
                        f"Total processing time: {update_end - receive_time:.6f}s | "
                        f"Round-trip time since observation sent: {receive_time - self.last_obs_sent_time:.6f}s"
                    )

            except grpc.RpcError as e:
                logger.error(f"Error receiving actions: {e}")
                time.sleep(idle_wait)  # Avoid tight loop on error

    def _get_next_action(self) -> Optional[TimedAction]:
        """Get the next action from the queue, or None when the queue is empty"""
        try:
            return self.action_queue.get_nowait()
        except Empty:
            return None

    def execute_actions(self):
        """Continuously execute actions from the queue.

        Thread body: waits on the start barrier, then on every cycle sleeps
        ``environment_dt``, pops one action and sends it to the robot under ``robot_lock``.
        """
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        logger.info("Action execution thread starting")

        while self.running:
            # Get the next action from the queue
            cycle_start = time.time()
            time.sleep(environment_dt)

            get_start = time.time()
            timed_action = self._get_next_action()
            get_end = time.time()

            if timed_action is not None:
                # Remember the timestamp of the newest dequeued action; _validate_action
                # compares incoming actions against it.
                self.latest_action = timed_action.get_timestamp()
                action_timestep = timed_action.get_timestep()

                # Send the action to the robot - Acquire lock before accessing the robot
                lock_start = time.time()
                if self.robot_lock.acquire(timeout=1.0):  # Wait up to 1 second to acquire the lock
                    lock_acquired = time.time()
                    try:
                        send_start = time.time()
                        self.robot.send_action(timed_action.get_action())
                        send_end = time.time()

                        self.actions_executed_count += 1

                        logger.info(
                            f"Executed action #{action_timestep} | "
                            f"Queue get time: {get_end - get_start:.6f}s | "
                            f"Lock wait time: {lock_acquired - lock_start:.6f}s | "
                            f"Action send time: {send_end - send_start:.6f}s | "
                            f"Total execution time: {send_end - cycle_start:.6f}s | "
                            f"Action latency: {send_end - timed_action.get_timestamp():.6f}s | "
                            f"Queue size: {self.action_queue.qsize()}"
                        )
                    finally:
                        # Always release the lock in a finally block to ensure it's released
                        self.robot_lock.release()
                else:
                    # NOTE: the action popped above is not put back on the queue.
                    logger.warning("Could not acquire robot lock for action execution, retrying next cycle")
            else:
                if get_end - get_start > 0.001:  # Only log if there was a measurable delay
                    logger.debug(f"No action available, get time: {get_end - get_start:.6f}s")
                time.sleep(idle_wait)

    def stream_observations(self, get_observation_fn):
        """Continuously stream observations to the server.

        Thread body: waits on the start barrier, then on every cycle sleeps
        ``environment_dt``, obtains an observation and sends it.

        Args:
            get_observation_fn: Zero-argument callable returning a TimedObservation, or
                None when no observation could be captured this cycle.
        """
        # Wait at barrier for synchronized start
        self.start_barrier.wait()
        logger.info("Observation streaming thread starting")

        first_observation = True
        while self.running:
            try:
                cycle_start = time.time()
                time.sleep(environment_dt)

                get_start = time.time()
                observation = get_observation_fn()
                get_end = time.time()

                # Skip if observation is None (couldn't acquire lock)
                if observation is None:
                    logger.warning("Failed to get observation, skipping cycle")
                    continue

                # Set appropriate transfer state
                if first_observation:
                    state = async_inference_pb2.TRANSFER_BEGIN
                else:
                    state = async_inference_pb2.TRANSFER_MIDDLE

                obs_timestep = observation.get_timestep()
                logger.debug(f"Got observation #{obs_timestep} in {get_end - get_start:.6f}s, sending...")

                send_start = time.time()
                sent = self.send_observation(observation, state)
                send_end = time.time()

                # Only leave the "begin" state once an observation has actually gone out,
                # so a failed first send is retried with TRANSFER_BEGIN.
                if sent:
                    first_observation = False

                logger.info(
                    f"Observation #{obs_timestep} cycle complete | "
                    f"Get time: {get_end - get_start:.6f}s | "
                    f"Send time: {send_end - send_start:.6f}s | "
                    f"Total cycle time: {send_end - cycle_start:.6f}s"
                )

            except Exception as e:
                # Thread-level boundary: log and keep the streaming loop alive.
                logger.error(f"Error in observation sender: {e}")
                time.sleep(idle_wait)
def async_client():
    """Run a RobotClient end to end.

    Creates the client, connects it to the policy server, starts the three worker
    threads (observation sender, action receiver, action executor) and blocks until the
    client stops running or the user interrupts with Ctrl-C. The client is stopped on
    every exit path, including a failed connection to the policy server.
    """
    client = RobotClient()

    if not client.start():
        # The constructor already connected the robot and created the channel. Release
        # both instead of leaking them when the server handshake fails.
        client.stop()
        return

    # Timestep assigned to the next captured observation
    next_timestep = 0

    def get_observation():
        """Capture one observation from the robot, or return None if the robot lock is busy."""
        nonlocal next_timestep

        # Acquire lock before accessing the robot
        start_time = time.time()
        if not client.robot_lock.acquire(timeout=1.0):  # Wait up to 1 second to acquire the lock
            logger.warning("Could not acquire robot lock for observation capture, skipping this cycle")
            return None  # Return None to indicate no observation was captured

        lock_time = time.time()
        try:
            capture_start = time.time()
            observation_content = client.robot.capture_observation()
            capture_end = time.time()

            logger.debug(
                f"Observation capture | "
                f"Lock acquisition: {lock_time - start_time:.6f}s | "
                f"Capture time: {capture_end - capture_start:.6f}s"
            )
        finally:
            # Always release the lock in a finally block to ensure it's released
            client.robot_lock.release()

        observation = TimedObservation(
            timestamp=time.time(), observation=observation_content, timestep=next_timestep
        )

        # Advance the counter only after a successful capture
        next_timestep += 1

        end_time = time.time()
        logger.debug(
            f"Observation #{observation.get_timestep()} prepared | "
            f"Total time: {end_time - start_time:.6f}s"
        )

        return observation

    logger.info("Starting all threads...")

    # Daemon threads, so a stuck worker cannot keep the process alive at exit
    workers = [
        threading.Thread(target=client.stream_observations, args=(get_observation,), daemon=True),
        threading.Thread(target=client.receive_actions, daemon=True),
        threading.Thread(target=client.execute_actions, daemon=True),
    ]
    for worker in workers:
        worker.start()

    try:
        # Main thread just keeps everything alive
        while client.running:
            time.sleep(idle_wait)
    except KeyboardInterrupt:
        pass
    finally:
        # Ask the workers to leave their loops and give them a bounded chance to do so
        # before the robot is disconnected underneath them.
        client.running = False
        for worker in workers:
            worker.join(timeout=1.0)

        client.stop()
        logger.info("Client stopped")
# Script entry point: run the client only when this file is executed directly, not on import.
if __name__ == "__main__":
    async_client()