import json

import rclpy
import gradio as gr
from loguru import logger

from llm_request_handler import LLMRequestHandler
from ros_node_publisher import RosNodePublisher
from json_processor import JsonProcessor
from dag_visualizer import DAGVisualizer
from config import ROBOTS_CONFIG, MODEL_CONFIG, MODE_CONFIG


class GradioLlmInterface:
    def __init__(self):
        self.node_publisher = None
        self.received_tasks = []
        self.json_processor = JsonProcessor()
        self.dag_visualizer = DAGVisualizer()
    def initialize_interface(self, mode):
        if not rclpy.ok():
            rclpy.init()
        self.node_publisher = RosNodePublisher(mode)
        mode_config = MODE_CONFIG[mode]
        # Use provider and model version from mode_config
        model_version = mode_config.get("model_version", MODEL_CONFIG["default_model"])
        provider = mode_config.get("provider", "openai")
        llm_handler = LLMRequestHandler(
            model_name=model_version,
            provider=provider,
            max_tokens=mode_config.get("max_tokens", MODEL_CONFIG["max_tokens"]),
            temperature=mode_config.get("temperature", MODEL_CONFIG["temperature"]),
            frequency_penalty=mode_config.get("frequency_penalty", MODEL_CONFIG["frequency_penalty"]),
            list_navigation_once=True
        )
        file_path = mode_config["prompt_file"]
        initial_messages = llm_handler.build_initial_messages(file_path, mode)
        # Don't create the chatbot here; return state data only
        state_data = {
            "file_path": file_path,
            "initial_messages": initial_messages,
            "mode": mode,
            # Store the config dict (including provider) so a handler can be rebuilt per request
            "llm_config": llm_handler.get_config_dict()
        }
        return state_data
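
    # For reference, initialize_interface expects each MODE_CONFIG entry to look
    # roughly like the sketch below. This is an illustrative assumption inferred
    # from the keys read above (config.py is not shown here); actual entries and
    # the mode key names may differ.
    #
    # MODE_CONFIG = {
    #     "dump_truck": {                            # hypothetical mode key
    #         "type": "simple",                      # or "complex" (see predict)
    #         "prompt_file": "prompts/dump_truck.txt",
    #         "model_version": "gpt-4o",             # falls back to MODEL_CONFIG["default_model"]
    #         "provider": "openai",                  # falls back to "openai"
    #         "max_tokens": 2048,                    # optional; MODEL_CONFIG supplies defaults
    #         "temperature": 0.2,
    #         "frequency_penalty": 0.0,
    #     },
    # }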

    async def predict(self, input, state):
        if not self.node_publisher.is_initialized():
            mode = state.get('mode')
            self.node_publisher.initialize_node(mode)
        initial_messages = state['initial_messages']
        full_history = initial_messages + state.get('history', [])
        user_input = f"# Query: {input}"
        full_history.append({"role": "user", "content": user_input})
        mode_config = MODE_CONFIG[state.get('mode')]
        if mode_config["type"] == 'complex' and self.received_tasks:
            for task in self.received_tasks:
                task_prompt = f"# Task: {task}"
                full_history.append({"role": "user", "content": task_prompt})
            self.received_tasks = []
        # Create a new LLMRequestHandler instance for each request
        llm_config = state['llm_config']
        llm_handler = LLMRequestHandler.create_from_config_dict(llm_config)
        response = await llm_handler.make_completion(full_history)
        if response:
            full_history.append({"role": "assistant", "content": response})
        else:
            response = "Error: Unable to get response."
            full_history.append({"role": "assistant", "content": response})
        response_json = self.json_processor.process_response(response)
        # Store the task plan for the approval workflow
        state.update({'pending_task_plan': response_json})
        # Generate a DAG visualization if valid task data is available
        dag_image_path = None
        confirm_button_visible = False
        if response_json and "tasks" in response_json:
            try:
                dag_image_path = self.dag_visualizer.create_dag_visualization(
                    response_json,
                    title="Robot Task Dependency Graph - Pending Approval"
                )
                logger.info(f"DAG visualization generated: {dag_image_path}")
                confirm_button_visible = True
            except Exception as e:
                logger.error(f"Failed to generate DAG visualization: {e}")
        # Reformat the new turns to match the chatbot's "messages" type
        messages = [{"role": message["role"], "content": message["content"]} for message in full_history[len(initial_messages):]]
        updated_history = state.get('history', []) + [{"role": "user", "content": input}, {"role": "assistant", "content": response}]
        state.update({'history': updated_history})
        return messages, state, dag_image_path, gr.update(visible=confirm_button_visible)

    def clear_chat(self, state):
        # Reset the conversation history; pending/deployed plans are kept.
        state['history'] = []
        # Return a cleared chatbot value alongside the state (same output shape
        # as update_chatbot), assuming the caller wires this to [chatbot, state].
        return gr.update(value=[]), state

    def show_task_plan_editor(self, state):
        """
        Show the current task plan in JSON format for manual editing.
        """
        # Check for a pending plan first, then fall back to the deployed plan
        pending_plan = state.get('pending_task_plan')
        deployed_plan = state.get('deployed_task_plan')
        current_plan = pending_plan if pending_plan else deployed_plan
        if current_plan and "tasks" in current_plan and len(current_plan["tasks"]) > 0:
            # Format the JSON for better readability
            formatted_json = json.dumps(current_plan, indent=2, ensure_ascii=False)
            plan_status = "pending" if pending_plan else "deployed"
            logger.info(f"📝 Task plan editor opened with {plan_status} plan")
            # Set the pending plan for editing (copy from deployed if needed)
            if not pending_plan and deployed_plan:
                state.update({'pending_task_plan': deployed_plan})
            return (
                gr.update(visible=True, value=formatted_json),  # Show editor with current JSON
                gr.update(visible=True),   # Show Update DAG button
                gr.update(visible=False),  # Hide Validate & Deploy button
                f"📝 **Task Plan Editor Opened**\n\nYou can now manually edit the task plan JSON below. {plan_status.title()} plan loaded for editing."
            )
        else:
            # Provide a template with an example structure
            template_json = """{
  "tasks": [
    {
      "task": "example_task_1",
      "instruction_function": {
        "name": "example_function_name",
        "robot_ids": ["robot_dump_truck_01"],
        "dependencies": [],
        "object_keywords": ["object1", "object2"]
      }
    }
  ]
}"""
            logger.info("📝 Task plan editor opened with template")
            return (
                gr.update(visible=True, value=template_json),  # Show template
                gr.update(visible=True),   # Show Update DAG button
                gr.update(visible=False),  # Hide Validate & Deploy button
                "⚠️ **No Task Plan Available**\n\nStarting with example template. Please edit the JSON structure and update."
            )

    def update_dag_from_editor(self, edited_json, state):
        """
        Update the DAG visualization from manually edited JSON.
        """
        try:
            # Parse the edited JSON
            edited_plan = json.loads(edited_json)
            # Validate the JSON structure
            if "tasks" not in edited_plan:
                raise ValueError("JSON must contain 'tasks' field")
            # Store the edited plan
            state.update({'pending_task_plan': edited_plan})
            # Generate the updated DAG visualization
            dag_image_path = self.dag_visualizer.create_dag_visualization(
                edited_plan,
                title="Robot Task Dependency Graph - EDITED & PENDING APPROVAL"
            )
            logger.info("📝 DAG updated from manual edits")
            return (
                dag_image_path,
                gr.update(visible=True),   # Show Validate & Deploy button
                gr.update(visible=False),  # Hide editor
                gr.update(visible=False),  # Hide Update DAG button
                "✅ **DAG Updated Successfully**\n\nTask plan has been updated with your edits. Please review the visualization and click 'Validate & Deploy' to proceed.",
                state
            )
        except json.JSONDecodeError as e:
            error_msg = f"❌ **JSON Parsing Error**\n\nInvalid JSON format: {str(e)}\n\nPlease fix the JSON syntax and try again."
            return (
                None,
                gr.update(visible=False),  # Hide Validate & Deploy button
                gr.update(visible=True),   # Keep editor visible
                gr.update(visible=True),   # Keep Update DAG button visible
                error_msg,
                state
            )
        except Exception as e:
            error_msg = f"❌ **Update Failed**\n\nError: {str(e)}"
            logger.error(f"Failed to update DAG from editor: {e}")
            return (
                None,
                gr.update(visible=False),  # Hide Validate & Deploy button
                gr.update(visible=True),   # Keep editor visible
                gr.update(visible=True),   # Keep Update DAG button visible
                error_msg,
                state
            )

    def validate_and_deploy_task_plan(self, state):
        """
        Validate and deploy the task plan to the construction site.
        This function implements the safety confirmation workflow.
        """
        pending_plan = state.get('pending_task_plan')
        if pending_plan:
            try:
                # Deploy the approved task plan to ROS
                self.node_publisher.publish_response(pending_plan)
                # Update the DAG visualization to show the approved status
                approved_image_path = None
                if "tasks" in pending_plan:
                    approved_image_path = self.dag_visualizer.create_dag_visualization(
                        pending_plan,
                        title="Robot Task Dependency Graph - APPROVED & DEPLOYED"
                    )
                # Keep the deployed plan for potential re-editing, but mark it as deployed
                state.update({'deployed_task_plan': pending_plan, 'pending_task_plan': None})
                logger.info("✅ Task plan validated and deployed to construction site")
                # Return the confirmation message and updated visualization
                confirmation_msg = "✅ **Task Plan Successfully Deployed**\n\nThe validated task dependency graph has been sent to the construction site robots. All safety protocols confirmed."
                return (
                    confirmation_msg,
                    approved_image_path,
                    gr.update(visible=False),  # Hide confirmation button
                    state
                )
            except Exception as e:
                logger.error(f"Failed to deploy task plan: {e}")
                error_msg = f"❌ **Deployment Failed**\n\nError: {str(e)}"
                return (
                    error_msg,
                    None,
                    gr.update(visible=True),  # Keep button visible for retry
                    state
                )
        else:
            warning_msg = "⚠️ **No Task Plan to Deploy**\n\nPlease generate a task plan first."
            return (
                warning_msg,
                None,
                gr.update(visible=False),
                state
            )
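
    # Approval workflow summary (as implemented above):
    #   1. predict() stores the LLM's plan in state['pending_task_plan'] and
    #      renders the DAG for review.
    #   2. show_task_plan_editor() / update_dag_from_editor() allow optional
    #      manual edits to the pending plan.
    #   3. validate_and_deploy_task_plan() publishes the approved plan to ROS
    #      and moves it to state['deployed_task_plan'].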

    def update_chatbot(self, mode, state):
        # Destroy and reinitialize the ROS node
        if self.node_publisher is not None:
            self.node_publisher.destroy_node()
        if not rclpy.ok():
            rclpy.init()
        self.node_publisher = RosNodePublisher(mode)
        self.json_processor = JsonProcessor()
        # Update the LLM handler with the new model settings
        mode_config = MODE_CONFIG[mode]
        model_version = mode_config.get("model_version", MODEL_CONFIG["default_model"])
        model_type = mode_config.get("model_type", "openai")  # Ensure the correct model_type is used
        provider = mode_config.get("provider", MODEL_CONFIG["provider"])
        # Re-instantiate LLMRequestHandler with the new model version and type
        # (keyword aligned with initialize_interface, which passes model_name=)
        llm_handler = LLMRequestHandler(
            model_name=model_version,
            provider=provider,
            max_tokens=mode_config.get("max_tokens", MODEL_CONFIG["max_tokens"]),
            temperature=mode_config.get("temperature", MODEL_CONFIG["temperature"]),
            frequency_penalty=mode_config.get("frequency_penalty", MODEL_CONFIG["frequency_penalty"]),
            list_navigation_once=True,
            model_type=model_type
        )
        # Update the prompt file and initial messages
        file_path = mode_config["prompt_file"]
        initial_messages = llm_handler.build_initial_messages(file_path, mode)
        # Update state with the new handler config and reset history
        logger.info(f"Updating chatbot with {file_path}, model {model_version}, provider {provider}")
        state['file_path'] = file_path
        state['initial_messages'] = initial_messages
        state['history'] = []
        state['mode'] = mode
        state['llm_config'] = llm_handler.get_config_dict()
        logger.info(f"\033[33mMode updated to {mode}\033[0m")
        return gr.update(value=[]), state
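

# Minimal launch sketch: how this class might be wired into a Gradio Blocks app.
# Everything below is illustrative; the component names, the "dump_truck" mode
# key, and the callback wiring are assumptions for demonstration. The project's
# real launcher lives elsewhere and may differ.
if __name__ == "__main__":
    interface = GradioLlmInterface()
    state_data = interface.initialize_interface("dump_truck")  # hypothetical mode key

    with gr.Blocks() as demo:
        session_state = gr.State(state_data)
        chatbot = gr.Chatbot(type="messages")
        user_textbox = gr.Textbox(label="Instruction")
        dag_image = gr.Image(label="Task DAG")
        plan_editor = gr.Code(language="json", visible=False)
        edit_plan_btn = gr.Button("Edit Task Plan")
        update_dag_btn = gr.Button("Update DAG", visible=False)
        validate_deploy_btn = gr.Button("Validate & Deploy", visible=False)
        status_md = gr.Markdown()

        # predict returns (messages, state, dag_image_path, confirm-button update)
        user_textbox.submit(
            interface.predict,
            inputs=[user_textbox, session_state],
            outputs=[chatbot, session_state, dag_image, validate_deploy_btn],
        )
        edit_plan_btn.click(
            interface.show_task_plan_editor,
            inputs=[session_state],
            outputs=[plan_editor, update_dag_btn, validate_deploy_btn, status_md],
        )
        update_dag_btn.click(
            interface.update_dag_from_editor,
            inputs=[plan_editor, session_state],
            outputs=[dag_image, validate_deploy_btn, plan_editor, update_dag_btn, status_md, session_state],
        )
        validate_deploy_btn.click(
            interface.validate_and_deploy_task_plan,
            inputs=[session_state],
            outputs=[status_md, dag_image, validate_deploy_btn, session_state],
        )

    demo.launch()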