import asyncio
import logging
from abc import ABC, abstractmethod

import numpy as np
import torch
from PIL import Image

from .joint_config import JointConfig

logger = logging.getLogger(__name__)


class BaseInferenceEngine(ABC):
    """
    Base class for all inference engines.

    This class provides common functionality for:
    - Image preprocessing and normalization
    - Joint data handling and validation
    - Model loading and management
    - Action prediction interface
    """

    def __init__(
        self,
        policy_path: str,
        camera_names: list[str],
        device: str | None = None,
    ):
        self.policy_path = policy_path
        self.camera_names = camera_names

        # Device selection
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Using device: {self.device}")

        # Model and preprocessing
        self.policy = None
        self.image_transforms = {}  # {camera_name: transform}
        self.stats = None  # Dataset statistics for normalization

        # State tracking
        self.is_loaded = False
        self.last_images = {}
        self.last_joint_positions = None

    @abstractmethod
    async def load_policy(self):
        """Load the policy model. Must be implemented by subclasses."""

    @abstractmethod
    async def predict(
        self, images: dict[str, np.ndarray], joint_positions: np.ndarray, **kwargs
    ) -> np.ndarray:
        """Run inference. Must be implemented by subclasses."""
    def preprocess_images(
        self, images: dict[str, np.ndarray]
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess images for inference.

        Args:
            images: Dictionary of {camera_name: rgb_image_array}

        Returns:
            Dictionary of {camera_name: preprocessed_tensor}
        """
        processed_images = {}

        for camera_name, image in images.items():
            if camera_name not in self.camera_names:
                logger.warning(f"Unexpected camera: {camera_name}")
                continue

            # Convert numpy array to PIL Image if needed
            if isinstance(image, np.ndarray):
                if image.dtype != np.uint8:
                    # Non-uint8 arrays are assumed to be floats in [0, 1];
                    # clip before scaling to avoid uint8 overflow
                    image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
                pil_image = Image.fromarray(image)
            else:
                pil_image = image

            # Apply camera-specific transforms if available
            if camera_name in self.image_transforms:
                tensor = self.image_transforms[camera_name](pil_image)
            else:
                # Default preprocessing: resize to 224x224 and scale to [0, 1]
                tensor = self._default_image_transform(pil_image)

            processed_images[camera_name] = tensor.to(self.device)

        return processed_images
    def _default_image_transform(self, image: Image.Image) -> torch.Tensor:
        """Default image preprocessing."""
        # Resize to 224x224 (common size for vision models)
        image = image.resize((224, 224), Image.Resampling.LANCZOS)

        # Convert to tensor and normalize to [0, 1]
        tensor = torch.from_numpy(np.array(image)).float() / 255.0

        # Rearrange from HWC to CHW
        if len(tensor.shape) == 3:
            tensor = tensor.permute(2, 0, 1)

        return tensor
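    # Illustrative example for `_default_image_transform` (input shape is an
    # assumption, not a requirement): a 480x640 uint8 RGB frame comes back as
    # a float32 tensor of shape (3, 224, 224) with values in [0, 1].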
    def preprocess_joint_positions(self, joint_positions: np.ndarray) -> torch.Tensor:
        """
        Preprocess joint positions for inference.

        Args:
            joint_positions: Array of joint positions in standard order

        Returns:
            Preprocessed joint tensor
        """
        # Validate and clamp joint values
        joint_positions = JointConfig.validate_joint_values(joint_positions)

        # Convert to tensor
        joint_tensor = torch.from_numpy(joint_positions).float().to(self.device)

        # Normalize if we have dataset statistics
        if self.stats and hasattr(self.stats, "joint_stats"):
            joint_tensor = self._normalize_joints(joint_tensor)

        return joint_tensor
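    # Usage note (sketch only): a NumPy joint vector in the standard order is
    # validated and clamped via JointConfig, moved to `self.device` as float32,
    # and optionally normalized when dataset statistics are loaded.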
    def _normalize_joints(self, joint_tensor: torch.Tensor) -> torch.Tensor:
        """Normalize joint values using dataset statistics."""
        # Placeholder: a concrete implementation would apply the dataset's
        # joint statistics here (e.g. mean/std normalization). For now the
        # values are returned unchanged, i.e. joints are assumed to be
        # already normalized.
        return joint_tensor
    def get_joint_commands_with_names(self, action: np.ndarray) -> list[dict]:
        """
        Convert action array to joint commands with names.

        Args:
            action: Array of joint actions in standard order

        Returns:
            List of joint command dictionaries
        """
        # Validate action values
        action = JointConfig.validate_joint_values(action)

        # Create commands with AI names (always use AI names for output)
        return JointConfig.create_joint_commands(action)
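    # Note: the exact dictionary layout of each command is determined by
    # JointConfig.create_joint_commands, not by this base class.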
    def reset(self):
        """Reset the inference engine state."""
        self.last_images = {}
        self.last_joint_positions = None

        # Clear any model-specific state
        if hasattr(self.policy, "reset"):
            self.policy.reset()

    async def cleanup(self):
        """Clean up resources."""
        if self.policy:
            del self.policy
            self.policy = None

        # Clear GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.is_loaded = False
        logger.info(f"Cleaned up inference engine for {self.policy_path}")

    def __del__(self):
        """Destructor to schedule async cleanup if an event loop is running."""
        if hasattr(self, "policy") and self.policy:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running event loop (e.g. interpreter shutdown); skip async cleanup
                return
            loop.create_task(self.cleanup())
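
# ---------------------------------------------------------------------------
# Example subclass (illustrative sketch only; `EchoInferenceEngine` is a
# hypothetical name, not part of this module's API). It shows how the two
# abstract methods above might be implemented using the shared helpers.
#
#   class EchoInferenceEngine(BaseInferenceEngine):
#       async def load_policy(self):
#           # A real engine would load weights from self.policy_path here.
#           self.policy = torch.nn.Identity().to(self.device)
#           self.is_loaded = True
#
#       async def predict(self, images, joint_positions, **kwargs):
#           # Preprocess inputs with the base-class helpers, then echo the
#           # current joint positions back as a trivial "action".
#           _ = self.preprocess_images(images)
#           joints = self.preprocess_joint_positions(joint_positions)
#           return joints.cpu().numpy()
# ---------------------------------------------------------------------------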