"""HuggingFace Inference Endpoints handler for pi0 model inference via openpi."""

import base64
import json
import os
from io import BytesIO
from typing import Any, Dict, List, Optional

import numpy as np
from PIL import Image

from openpi.policies import policy_config
from openpi.training import config as train_config


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the handler for pi0 model inference using the openpi infrastructure.

        Args:
            path: Path to the model weights directory.
        """
        model_path = os.environ.get("MODEL_PATH", path)
        if not model_path:
            model_path = "weights/pi0"

        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            model_config = json.load(f)

        model_type = model_config.get("type", "pi0")

        if model_type == "pi0":
            self.train_config = train_config.get_config("pi0")
        else:
            # Only pi0 is currently supported; fall back to the pi0 training
            # config for any other reported model type.
            self.train_config = train_config.get_config("pi0")

        self.policy = policy_config.create_trained_policy(
            self.train_config,
            model_path,
            pytorch_device="cuda" if os.environ.get("CUDA_VISIBLE_DEVICES") else "cpu",
        )

        # Default action horizon when the request does not specify one.
        self.default_num_steps = 50

    def _decode_base64_image(self, base64_str: str) -> np.ndarray:
        """
        Decode a base64 image string to a numpy array.

        Args:
            base64_str: Base64-encoded image string, optionally prefixed with
                a data URI such as "data:image/png;base64,".

        Returns:
            uint8 numpy array of shape (H, W, 3) with values in [0, 255].
        """
        # Strip the data-URI prefix if present.
        if base64_str.startswith("data:image"):
            base64_str = base64_str.split(",", 1)[1]

        image_bytes = base64.b64decode(base64_str)
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        return np.array(image)

    def _prepare_observation(self, images: Dict[str, str], state: List[float], prompt: Optional[str] = None) -> Dict[str, Any]:
        """
        Prepare an observation dictionary in the format expected by openpi.

        Args:
            images: Dictionary mapping camera names to base64-encoded images.
            state: List of robot state values.
            prompt: Optional text prompt.

        Returns:
            Observation dictionary in openpi format.
        """
        processed_images = {}

        # Map common client-side camera names onto the names openpi expects.
        camera_mapping = {
            "camera0": "cam_high",
            "camera1": "cam_left_wrist",
            "camera2": "cam_right_wrist",
            "base_camera": "cam_high",
            "left_wrist": "cam_left_wrist",
            "right_wrist": "cam_right_wrist",
            "cam_high": "cam_high",
            "cam_left_wrist": "cam_left_wrist",
            "cam_right_wrist": "cam_right_wrist",
        }

        for input_name, image_b64 in images.items():
            # Unknown camera names are passed through unchanged.
            openpi_name = camera_mapping.get(input_name, input_name)
            image_array = self._decode_base64_image(image_b64)

            # Resize to the 224x224 resolution the model expects.
            if image_array.shape[:2] != (224, 224):
                image_pil = Image.fromarray(image_array)
                image_array = np.array(image_pil.resize((224, 224)))

            processed_images[openpi_name] = image_array.astype(np.uint8)

        # Pad any missing camera views with black images so the observation
        # always contains the full set of cameras.
        required_cameras = ["cam_high", "cam_left_wrist", "cam_right_wrist"]
        for cam_name in required_cameras:
            if cam_name not in processed_images:
                processed_images[cam_name] = np.zeros((224, 224, 3), dtype=np.uint8)

        state_array = np.array(state, dtype=np.float32)

        observation = {
            "state": state_array,
            "images": processed_images,
        }
        if prompt:
            observation["prompt"] = prompt

        return observation

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Main inference function called by the HuggingFace endpoint.

        Args:
            data: Input data dictionary containing:
                - inputs: Dictionary with:
                    - images: Dict mapping camera names to base64-encoded images
                    - state: List of robot state values
                    - prompt: Optional text prompt
                    - num_actions: Optional number of actions to return (default: 50)
                    - noise: Optional noise array for sampling

        Returns:
            List containing the prediction results. (See the usage sketch at
            the bottom of this file for an example request.)
        """
        try:
            inputs = data.get("inputs", {})

            images = inputs.get("images", {})
            state = inputs.get("state", [])
            prompt = inputs.get("prompt", "")
            num_actions = inputs.get("num_actions", self.default_num_steps)
            noise_input = inputs.get("noise", None)

            if not images:
                raise ValueError("No images provided")
            if not state:
                raise ValueError("No state provided")

            observation = self._prepare_observation(images, state, prompt)

            noise = None
            if noise_input is not None:
                noise = np.array(noise_input, dtype=np.float32)

            result = self.policy.infer(observation, noise=noise)
            actions = result["actions"]

            if isinstance(actions, np.ndarray):
                actions_list = actions.tolist()
            else:
                actions_list = actions

            # The policy returns a fixed action horizon; trim to the number of
            # actions the client requested.
            actions_list = actions_list[:num_actions]

            return [{
                "actions": actions_list,
                "num_actions": len(actions_list),
                "action_horizon": len(actions_list),
                "action_dim": len(actions_list[0]) if actions_list else 0,
                "success": True,
                "metadata": {
                    "model_type": self.train_config.model.model_type.value,
                    "policy_metadata": getattr(self.policy, "_metadata", {}),
                },
            }]

        except Exception as e:
            return [{
                "error": str(e),
                "success": False,
            }]
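

# ---------------------------------------------------------------------------
# Usage sketch (not part of the endpoint contract): a minimal local smoke test
# illustrating the request payload shape documented in __call__. It assumes
# MODEL_PATH (or "weights/pi0") points at a valid pi0 checkpoint; the image is
# a synthetic gray frame and the 14-dim zero state is only a placeholder whose
# dimension depends on the robot the checkpoint was trained for.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Encode a synthetic 224x224 RGB frame as base64, as a client would.
    buffer = BytesIO()
    Image.fromarray(np.full((224, 224, 3), 128, dtype=np.uint8)).save(buffer, format="PNG")
    image_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

    handler = EndpointHandler()
    response = handler({
        "inputs": {
            "images": {"base_camera": image_b64},
            "state": [0.0] * 14,  # placeholder state vector; adjust per robot
            "prompt": "pick up the cube",
            "num_actions": 10,
        }
    })
    print("keys:", list(response[0].keys()))
    print("num_actions:", response[0].get("num_actions"))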