Spaces:
No application file
No application file
import json

import gradio as gr
import rclpy
from loguru import logger

from config import ROBOTS_CONFIG, MODEL_CONFIG, MODE_CONFIG
from dag_visualizer import DAGVisualizer
from json_processor import JsonProcessor
from llm_request_handler import LLMRequestHandler
from ros_node_publisher import RosNodePublisher
class GradioLlmInterface:
    """Gradio front-end that turns chat queries into robot task plans.

    Flow implemented here:

    * ``predict`` sends the conversation to an LLM.
    * The reply is parsed with ``JsonProcessor`` and rendered as a DAG image
      with ``DAGVisualizer``.
    * The parsed plan is kept in the per-session ``state`` dict under
      ``'pending_task_plan'`` until the user approves it with
      ``validate_and_deploy_task_plan``, which publishes it through
      ``RosNodePublisher``.
    * ``show_task_plan_editor`` and ``update_dag_from_editor`` let the user
      hand-edit the pending plan as JSON before approval.

    Session ``state`` keys read or written here: ``file_path``,
    ``initial_messages``, ``mode``, ``llm_config``, ``history``,
    ``pending_task_plan`` and ``deployed_task_plan``.
    """

    # Example plan shown in the editor when there is neither a pending nor a
    # deployed plan to edit.
    _TASK_PLAN_TEMPLATE = """{
  "tasks": [
    {
      "task": "example_task_1",
      "instruction_function": {
        "name": "example_function_name",
        "robot_ids": ["robot_dump_truck_01"],
        "dependencies": [],
        "object_keywords": ["object1", "object2"]
      }
    }
  ]
}"""

    def __init__(self):
        # Created by initialize_interface / update_chatbot (or lazily by predict).
        self.node_publisher = None
        # Task strings appended to the next 'complex'-mode query. Nothing in this
        # class appends to it; presumably it is filled from outside (e.g. a ROS
        # callback). TODO confirm.
        self.received_tasks = []
        self.json_processor = JsonProcessor()
        self.dag_visualizer = DAGVisualizer()

    # ------------------------------------------------------------------ helpers

    def _reset_node_publisher(self, mode):
        """Destroy any existing ROS publisher and create a fresh one for *mode*."""
        if self.node_publisher is not None:
            self.node_publisher.destroy_node()
        if not rclpy.ok():
            rclpy.init()
        self.node_publisher = RosNodePublisher(mode)

    @staticmethod
    def _resolve_model(mode_config):
        """Return ``(model_version, provider)`` for a MODE_CONFIG entry, with defaults."""
        model_version = mode_config.get("model_version", MODEL_CONFIG["default_model"])
        provider = mode_config.get("provider", MODEL_CONFIG.get("provider", "openai"))
        return model_version, provider

    @classmethod
    def _build_llm_handler(cls, mode_config):
        """Build an LLMRequestHandler from a MODE_CONFIG entry.

        Keys missing from *mode_config* fall back to MODEL_CONFIG. This is the
        single construction path shared by initialize_interface and
        update_chatbot, so both pass the same keyword arguments.
        """
        model_version, provider = cls._resolve_model(mode_config)
        return LLMRequestHandler(
            model_name=model_version,
            provider=provider,
            max_tokens=mode_config.get("max_tokens", MODEL_CONFIG["max_tokens"]),
            temperature=mode_config.get("temperature", MODEL_CONFIG["temperature"]),
            frequency_penalty=mode_config.get(
                "frequency_penalty", MODEL_CONFIG["frequency_penalty"]
            ),
            list_navigation_once=True,
        )

    @staticmethod
    def _editor_error_outputs(error_msg, state):
        """Gradio outputs for a failed editor update: editor stays open, deploy stays hidden."""
        return (
            None,
            gr.update(visible=False),  # Hide Validate & Deploy button
            gr.update(visible=True),   # Keep editor visible
            gr.update(visible=True),   # Keep Update DAG button visible
            error_msg,
            state,
        )

    # ------------------------------------------------------------ public API

    def initialize_interface(self, mode):
        """Set up the ROS publisher and LLM settings for *mode*.

        No chatbot component is created here. Returns the initial per-session
        state dict with ``file_path``, ``initial_messages``, ``mode`` and
        ``llm_config``.
        """
        self._reset_node_publisher(mode)
        mode_config = MODE_CONFIG[mode]
        llm_handler = self._build_llm_handler(mode_config)
        file_path = mode_config["prompt_file"]
        initial_messages = llm_handler.build_initial_messages(file_path, mode)
        return {
            "file_path": file_path,
            "initial_messages": initial_messages,
            "mode": mode,
            # Plain config dict; predict() rebuilds a handler from it per request.
            "llm_config": llm_handler.get_config_dict(),
        }

    async def predict(self, input, state):
        """Query the LLM with *input* and render the resulting task plan.

        Returns a 4-tuple:

        * the chat messages (``role``/``content`` dicts) after the initial
          prompt messages,
        * the updated state,
        * the DAG image path, or None,
        * a gr.update toggling the confirm button's visibility.
        """
        mode = state.get('mode')
        if self.node_publisher is None:
            self._reset_node_publisher(mode)
        if not self.node_publisher.is_initialized():
            self.node_publisher.initialize_node(mode)

        initial_messages = state['initial_messages']
        # `+` builds a new list, so appending below never mutates the state's lists.
        full_history = initial_messages + state.get('history', [])
        full_history.append({"role": "user", "content": f"# Query: {input}"})

        mode_config = MODE_CONFIG[mode]
        if mode_config["type"] == 'complex' and self.received_tasks:
            full_history.extend(
                {"role": "user", "content": f"# Task: {task}"} for task in self.received_tasks
            )
            self.received_tasks = []

        # A fresh handler per request, rebuilt from the serializable config in state.
        llm_handler = LLMRequestHandler.create_from_config_dict(state['llm_config'])
        response = await llm_handler.make_completion(full_history)
        if not response:
            response = "Error: Unable to get response."
        full_history.append({"role": "assistant", "content": response})

        response_json = self.json_processor.process_response(response)
        # Held until the user approves it (validate_and_deploy_task_plan).
        state['pending_task_plan'] = response_json

        dag_image_path = None
        confirm_button_visible = False
        if response_json and "tasks" in response_json:
            try:
                dag_image_path = self.dag_visualizer.create_dag_visualization(
                    response_json,
                    title="Robot Task Dependency Graph - Pending Approval",
                )
                logger.info(f"DAG visualization generated: {dag_image_path}")
                confirm_button_visible = True
            except Exception as e:
                logger.error(f"Failed to generate DAG visualization: {e}")

        # Chatbot "messages" format: only role/content, without the initial prompt messages.
        messages = [
            {"role": message["role"], "content": message["content"]}
            for message in full_history[len(initial_messages):]
        ]
        # NOTE(review): history stores the raw input (no "# Query:" prefix, no
        # "# Task:" lines), so later turns reach the LLM in a different shape
        # than the current one. Confirm this is intended.
        state['history'] = state.get('history', []) + [
            {"role": "user", "content": input},
            {"role": "assistant", "content": response},
        ]
        return messages, state, dag_image_path, gr.update(visible=confirm_button_visible)

    def clear_chat(self, state):
        """Reset the conversation history in *state* in place. Returns None."""
        state['history'] = []

    def show_task_plan_editor(self, state):
        """Open the JSON editor on the current task plan.

        Uses the pending plan if there is one, otherwise the deployed plan.
        When neither has tasks, an example template is shown instead.

        Returns updates for (editor, Update DAG button, Validate & Deploy
        button) and a Markdown status message.
        """
        pending_plan = state.get('pending_task_plan')
        deployed_plan = state.get('deployed_task_plan')
        current_plan = pending_plan if pending_plan else deployed_plan

        if not (isinstance(current_plan, dict) and current_plan.get("tasks")):
            logger.info("📝 Task plan editor opened with template")
            return (
                gr.update(visible=True, value=self._TASK_PLAN_TEMPLATE),  # Show template
                gr.update(visible=True),   # Show Update DAG button
                gr.update(visible=False),  # Hide Validate & Deploy button
                "⚠️ **No Task Plan Available**\n\nStarting with example template. Please edit the JSON structure and update."
            )

        formatted_json = json.dumps(current_plan, indent=2, ensure_ascii=False)
        plan_status = "pending" if pending_plan else "deployed"
        logger.info(f"📝 Task plan editor opened with {plan_status} plan")
        if plan_status == "deployed":
            # Editing a deployed plan produces a new pending plan that needs re-approval.
            state['pending_task_plan'] = deployed_plan
        return (
            gr.update(visible=True, value=formatted_json),  # Show editor with current JSON
            gr.update(visible=True),   # Show Update DAG button
            gr.update(visible=False),  # Hide Validate & Deploy button
            f"📝 **Task Plan Editor Opened**\n\nYou can now manually edit the task plan JSON below. {plan_status.title()} plan loaded for editing."
        )

    def update_dag_from_editor(self, edited_json, state):
        """Parse the edited plan JSON, store it as pending, and re-render the DAG.

        Returns (image path or None, Validate & Deploy button update, editor
        update, Update DAG button update, Markdown status message, state).
        Invalid input keeps the editor open and the deploy button hidden.
        """
        try:
            edited_plan = json.loads(edited_json)
            if not isinstance(edited_plan, dict):
                raise ValueError("JSON top level must be an object")
            if "tasks" not in edited_plan:
                raise ValueError("JSON must contain 'tasks' field")

            state['pending_task_plan'] = edited_plan
            dag_image_path = self.dag_visualizer.create_dag_visualization(
                edited_plan,
                title="Robot Task Dependency Graph - EDITED & PENDING APPROVAL",
            )
            logger.info("🔄 DAG updated from manual edits")
            return (
                dag_image_path,
                gr.update(visible=True),   # Show Validate & Deploy button
                gr.update(visible=False),  # Hide editor
                gr.update(visible=False),  # Hide Update DAG button
                "✅ **DAG Updated Successfully**\n\nTask plan has been updated with your edits. Please review the visualization and click 'Validate & Deploy' to proceed.",
                state,
            )
        except json.JSONDecodeError as e:
            error_msg = f"❌ **JSON Parsing Error**\n\nInvalid JSON format: {str(e)}\n\nPlease fix the JSON syntax and try again."
            return self._editor_error_outputs(error_msg, state)
        except Exception as e:
            logger.error(f"Failed to update DAG from editor: {e}")
            error_msg = f"❌ **Update Failed**\n\nError: {str(e)}"
            return self._editor_error_outputs(error_msg, state)

    def validate_and_deploy_task_plan(self, state):
        """Publish the approved pending plan to ROS (the safety confirmation step).

        The plan is marked deployed as soon as publishing succeeds. A failure
        to render the "approved" DAG afterwards is only logged, so it can never
        be reported as a failed deployment and invite a duplicate re-send.

        Returns (Markdown status message, image path or None, confirm button
        update, state).
        """
        pending_plan = state.get('pending_task_plan')
        if not pending_plan:
            return (
                "⚠️ **No Task Plan to Deploy**\n\nPlease generate a task plan first.",
                None,
                gr.update(visible=False),
                state,
            )

        try:
            self.node_publisher.publish_response(pending_plan)
        except Exception as e:
            logger.error(f"Failed to deploy task plan: {e}")
            return (
                f"❌ **Deployment Failed**\n\nError: {str(e)}",
                None,
                gr.update(visible=True),  # Keep button visible for retry
                state,
            )

        # The plan has been sent: record that before any non-critical work runs.
        # It is kept as 'deployed_task_plan' so it can be re-edited later.
        state.update({'deployed_task_plan': pending_plan, 'pending_task_plan': None})
        logger.info("✅ Task plan validated and deployed to construction site")

        approved_image_path = None
        if "tasks" in pending_plan:
            try:
                approved_image_path = self.dag_visualizer.create_dag_visualization(
                    pending_plan,
                    title="Robot Task Dependency Graph - APPROVED & DEPLOYED",
                )
            except Exception as e:
                logger.error(f"Failed to generate approved DAG visualization: {e}")

        confirmation_msg = "✅ **Task Plan Successfully Deployed**\n\nThe validated task dependency graph has been sent to the construction site robots. All safety protocols confirmed."
        return (
            confirmation_msg,
            approved_image_path,
            gr.update(visible=False),  # Hide confirmation button
            state,
        )

    def update_chatbot(self, mode, state):
        """Switch to *mode*: recreate the ROS node and LLM settings, reset history.

        Returns (update clearing the chatbot, state).
        """
        self._reset_node_publisher(mode)
        self.json_processor = JsonProcessor()

        mode_config = MODE_CONFIG[mode]
        model_version, provider = self._resolve_model(mode_config)
        llm_handler = self._build_llm_handler(mode_config)

        file_path = mode_config["prompt_file"]
        initial_messages = llm_handler.build_initial_messages(file_path, mode)

        logger.info(f"Updating chatbot with {file_path}, model {model_version}, provider {provider}")
        state['file_path'] = file_path
        state['initial_messages'] = initial_messages
        state['history'] = []
        state['mode'] = mode
        state['llm_config'] = llm_handler.get_config_dict()
        logger.info(f"\033[33mMode updated to {mode}\033[0m")
        return gr.update(value=[]), state