# DART-LLM_Task_Decomposer / gradio_llm_interface.py
import json

import rclpy
import gradio as gr
from loguru import logger
from llm_request_handler import LLMRequestHandler
from ros_node_publisher import RosNodePublisher
from json_processor import JsonProcessor
from dag_visualizer import DAGVisualizer
from config import ROBOTS_CONFIG, MODEL_CONFIG, MODE_CONFIG
class GradioLlmInterface:
def __init__(self):
self.node_publisher = None
self.received_tasks = []
self.json_processor = JsonProcessor()
self.dag_visualizer = DAGVisualizer()
def initialize_interface(self, mode):
if not rclpy.ok():
rclpy.init()
self.node_publisher = RosNodePublisher(mode)
mode_config = MODE_CONFIG[mode]
        # Resolve the model version and provider, falling back to the global defaults
        model_version = mode_config.get("model_version", MODEL_CONFIG["default_model"])
        provider = mode_config.get("provider", MODEL_CONFIG.get("provider", "openai"))
llm_handler = LLMRequestHandler(
model_name=model_version,
provider=provider,
max_tokens=mode_config.get("max_tokens", MODEL_CONFIG["max_tokens"]),
temperature=mode_config.get("temperature", MODEL_CONFIG["temperature"]),
frequency_penalty=mode_config.get("frequency_penalty", MODEL_CONFIG["frequency_penalty"]),
list_navigation_once=True
)
file_path = mode_config["prompt_file"]
initial_messages = llm_handler.build_initial_messages(file_path, mode)
        # Don't create the chatbot here; return the state data only
state_data = {
"file_path": file_path,
"initial_messages": initial_messages,
"mode": mode,
# store config dict with provider
"llm_config": llm_handler.get_config_dict()
}
return state_data
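    # Illustrative shape of the state dict returned by initialize_interface()
    # (the values shown are hypothetical; the real ones come from MODE_CONFIG
    # and LLMRequestHandler.get_config_dict()):
    #   {
    #       "file_path": "prompts/example_prompt.txt",   # hypothetical path
    #       "initial_messages": [{"role": "system", "content": "..."}],
    #       "mode": "default",                           # hypothetical mode key
    #       "llm_config": {...}
    #   }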
    async def predict(self, user_message, state):
if not self.node_publisher.is_initialized():
mode = state.get('mode')
self.node_publisher.initialize_node(mode)
initial_messages = state['initial_messages']
full_history = initial_messages + state.get('history', [])
        user_input = f"# Query: {user_message}"
full_history.append({"role": "user", "content": user_input})
mode_config = MODE_CONFIG[state.get('mode')]
if mode_config["type"] == 'complex' and self.received_tasks:
for task in self.received_tasks:
task_prompt = f"# Task: {task}"
full_history.append({"role": "user", "content": task_prompt})
self.received_tasks = []
# Create a new LLMRequestHandler instance for each request
llm_config = state['llm_config']
llm_handler = LLMRequestHandler.create_from_config_dict(llm_config)
response = await llm_handler.make_completion(full_history)
        if not response:
            response = "Error: Unable to get response."
        full_history.append({"role": "assistant", "content": response})
response_json = self.json_processor.process_response(response)
# Store the task plan for approval workflow
state.update({'pending_task_plan': response_json})
# Generate DAG visualization if valid task data is available
dag_image_path = None
confirm_button_visible = False
if response_json and "tasks" in response_json:
try:
dag_image_path = self.dag_visualizer.create_dag_visualization(
response_json,
title="Robot Task Dependency Graph - Pending Approval"
)
logger.info(f"DAG visualization generated: {dag_image_path}")
confirm_button_visible = True
except Exception as e:
logger.error(f"Failed to generate DAG visualization: {e}")
        # Convert the new turns (excluding the initial prompt messages) to the Gradio "messages" format
messages = [{"role": message["role"], "content": message["content"]} for message in full_history[len(initial_messages):]]
        updated_history = state.get('history', []) + [{"role": "user", "content": user_message}, {"role": "assistant", "content": response}]
state.update({'history': updated_history})
return messages, state, dag_image_path, gr.update(visible=confirm_button_visible)
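    # Illustrative shape of the `messages` value returned by predict(), matching
    # Gradio's "messages" chatbot type (the content strings are hypothetical):
    #   [{"role": "user", "content": "# Query: move the dump truck to zone A"},
    #    {"role": "assistant", "content": '{"tasks": [...]}'}]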
def clear_chat(self, state):
state['history'] = []
def show_task_plan_editor(self, state):
"""
Show the current task plan in JSON format for manual editing.
"""
        # Prefer the pending plan; fall back to the deployed plan
        pending_plan = state.get('pending_task_plan')
        deployed_plan = state.get('deployed_task_plan')
        current_plan = pending_plan if pending_plan else deployed_plan
if current_plan and "tasks" in current_plan and len(current_plan["tasks"]) > 0:
            # Format the JSON for readability
formatted_json = json.dumps(current_plan, indent=2, ensure_ascii=False)
plan_status = "pending" if pending_plan else "deployed"
logger.info(f"πŸ“ Task plan editor opened with {plan_status} plan")
# Set pending plan for editing (copy from deployed if needed)
if not pending_plan and deployed_plan:
state.update({'pending_task_plan': deployed_plan})
return (
gr.update(visible=True, value=formatted_json), # Show editor with current JSON
gr.update(visible=True), # Show Update DAG button
gr.update(visible=False), # Hide Validate & Deploy button
f"πŸ“ **Task Plan Editor Opened**\n\nYou can now manually edit the task plan JSON below. {plan_status.title()} plan loaded for editing."
)
else:
# Provide a better template with example structure
template_json = """{
"tasks": [
{
"task": "example_task_1",
"instruction_function": {
"name": "example_function_name",
"robot_ids": ["robot_dump_truck_01"],
"dependencies": [],
"object_keywords": ["object1", "object2"]
}
}
]
}"""
logger.info("πŸ“ Task plan editor opened with template")
return (
gr.update(visible=True, value=template_json), # Show template
gr.update(visible=True), # Show Update DAG button
gr.update(visible=False), # Hide Validate & Deploy button
"⚠️ **No Task Plan Available**\n\nStarting with example template. Please edit the JSON structure and update."
)
def update_dag_from_editor(self, edited_json, state):
"""
Update DAG visualization from manually edited JSON.
"""
try:
# Parse the edited JSON
edited_plan = json.loads(edited_json)
# Validate the JSON structure
if "tasks" not in edited_plan:
raise ValueError("JSON must contain 'tasks' field")
# Store the edited plan
state.update({'pending_task_plan': edited_plan})
# Generate updated DAG visualization
dag_image_path = self.dag_visualizer.create_dag_visualization(
edited_plan,
title="Robot Task Dependency Graph - EDITED & PENDING APPROVAL"
)
logger.info("πŸ”„ DAG updated from manual edits")
return (
dag_image_path,
gr.update(visible=True), # Show Validate & Deploy button
gr.update(visible=False), # Hide editor
gr.update(visible=False), # Hide Update DAG button
"βœ… **DAG Updated Successfully**\n\nTask plan has been updated with your edits. Please review the visualization and click 'Validate & Deploy' to proceed.",
state
)
except json.JSONDecodeError as e:
error_msg = f"❌ **JSON Parsing Error**\n\nInvalid JSON format: {str(e)}\n\nPlease fix the JSON syntax and try again."
return (
None,
gr.update(visible=False), # Hide Validate & Deploy button
gr.update(visible=True), # Keep editor visible
gr.update(visible=True), # Keep Update DAG button visible
error_msg,
state
)
except Exception as e:
error_msg = f"❌ **Update Failed**\n\nError: {str(e)}"
logger.error(f"Failed to update DAG from editor: {e}")
return (
None,
gr.update(visible=False), # Hide Validate & Deploy button
gr.update(visible=True), # Keep editor visible
gr.update(visible=True), # Keep Update DAG button visible
error_msg,
state
)
def validate_and_deploy_task_plan(self, state):
"""
Validate and deploy the task plan to the construction site.
This function implements the safety confirmation workflow.
"""
pending_plan = state.get('pending_task_plan')
if pending_plan:
try:
# Deploy the approved task plan to ROS
self.node_publisher.publish_response(pending_plan)
# Update DAG visualization to show approved status
approved_image_path = None
if "tasks" in pending_plan:
approved_image_path = self.dag_visualizer.create_dag_visualization(
pending_plan,
title="Robot Task Dependency Graph - APPROVED & DEPLOYED"
)
# Keep the deployed plan for potential re-editing, but mark as deployed
state.update({'deployed_task_plan': pending_plan, 'pending_task_plan': None})
logger.info("βœ… Task plan validated and deployed to construction site")
# Return confirmation message and updated visualization
confirmation_msg = "βœ… **Task Plan Successfully Deployed**\n\nThe validated task dependency graph has been sent to the construction site robots. All safety protocols confirmed."
return (
confirmation_msg,
approved_image_path,
gr.update(visible=False), # Hide confirmation button
state
)
except Exception as e:
logger.error(f"Failed to deploy task plan: {e}")
error_msg = f"❌ **Deployment Failed**\n\nError: {str(e)}"
return (
error_msg,
None,
gr.update(visible=True), # Keep button visible for retry
state
)
else:
warning_msg = "⚠️ **No Task Plan to Deploy**\n\nPlease generate a task plan first."
return (
warning_msg,
None,
gr.update(visible=False),
state
)
def update_chatbot(self, mode, state):
        # Destroy and reinitialize the ROS node
        if self.node_publisher is not None:
            self.node_publisher.destroy_node()
if not rclpy.ok():
rclpy.init()
self.node_publisher = RosNodePublisher(mode)
self.json_processor = JsonProcessor()
        # Rebuild the LLM handler with the settings for the new mode
        mode_config = MODE_CONFIG[mode]
        model_version = mode_config.get("model_version", MODEL_CONFIG["default_model"])
        provider = mode_config.get("provider", MODEL_CONFIG.get("provider", "openai"))
        # Re-instantiate LLMRequestHandler with the new model version and provider,
        # using the same constructor arguments as initialize_interface()
        llm_handler = LLMRequestHandler(
            model_name=model_version,
            provider=provider,
            max_tokens=mode_config.get("max_tokens", MODEL_CONFIG["max_tokens"]),
            temperature=mode_config.get("temperature", MODEL_CONFIG["temperature"]),
            frequency_penalty=mode_config.get("frequency_penalty", MODEL_CONFIG["frequency_penalty"]),
            list_navigation_once=True
        )
# Update the prompt file and initial messages
file_path = mode_config["prompt_file"]
initial_messages = llm_handler.build_initial_messages(file_path, mode)
# Update state with the new handler and reset history
logger.info(f"Updating chatbot with {file_path}, model {model_version}, provider {provider}")
state['file_path'] = file_path
state['initial_messages'] = initial_messages
state['history'] = []
state['mode'] = mode
        state['llm_config'] = llm_handler.get_config_dict()  # Store the new handler configuration
logger.info(f"\033[33mMode updated to {mode}\033[0m")
return gr.update(value=[]), state
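

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): one plausible way to wire this
# class into a Gradio Blocks app. The component names, layout, and event
# wiring below are assumptions for demonstration, not part of the original
# application.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    interface = GradioLlmInterface()
    default_mode = next(iter(MODE_CONFIG))  # assume any configured mode key works

    with gr.Blocks() as demo:
        state = gr.State(interface.initialize_interface(default_mode))
        chatbot = gr.Chatbot(type="messages", label="DART-LLM")
        user_box = gr.Textbox(label="Instruction")
        dag_image = gr.Image(label="Task DAG", interactive=False)
        deploy_btn = gr.Button("Validate & Deploy", visible=False)
        status = gr.Markdown()

        # predict() returns (messages, state, dag_image_path, button update)
        user_box.submit(
            interface.predict,
            inputs=[user_box, state],
            outputs=[chatbot, state, dag_image, deploy_btn],
        )
        # validate_and_deploy_task_plan() returns (message, image, update, state)
        deploy_btn.click(
            interface.validate_and_deploy_task_plan,
            inputs=[state],
            outputs=[status, dag_image, deploy_btn, state],
        )

    demo.launch()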