import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import warnings
import os
warnings.filterwarnings("ignore")
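
# Dependencies assumed available in the Space environment (e.g. via
# requirements.txt): gradio, torch, transformers, peft, and bitsandbytes
# (bitsandbytes is what BitsAndBytesConfig needs for 4-bit loading).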
# Model configuration
MODEL_NAME = "meta-llama/Llama-3.1-8B"
LORA_MODEL = "YongdongWang/llama-3.1-8b-dart-qlora"
def load_model():
    """Load the model and tokenizer."""
    print("🔄 Loading model...")
    try:
        # 4-bit quantization config (QLoRA-style NF4 with double quantization)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            use_fast=False,
            trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load the quantized base model
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
        # Attach the LoRA adapter on top of the base model
        model = PeftModel.from_pretrained(
            base_model,
            LORA_MODEL,
            torch_dtype=torch.float16
        )
        model.eval()
        print("✅ Model loaded successfully!")
        return model, tokenizer
    except Exception as load_error:
        print(f"❌ Model loading failed: {load_error}")
        return None, None
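
# Rough sanity check one might run locally after loading (a sketch, assuming a
# CUDA GPU; the exact footprint varies with library and driver versions): with
# NF4 double quantization, the 8B base weights typically fit in roughly 5-6 GB
# of VRAM, plus the comparatively small fp16 LoRA adapter.
#
#   m, t = load_model()
#   if m is not None:
#       print(f"{m.get_memory_footprint() / 1e9:.1f} GB")  # transformers API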
# Global model state (populated lazily on first request)
model = None
tokenizer = None
model_loading = False
def initialize_model():
    """Initialize the model (lazy loading on first use)."""
    global model, tokenizer, model_loading
    if model is not None and tokenizer is not None:
        return True
    if model_loading:
        return False
    model_loading = True
    try:
        model, tokenizer = load_model()
        return model is not None and tokenizer is not None
    finally:
        model_loading = False
def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
    """Generate a task-planning response for a single prompt."""
    if not initialize_model():
        if model_loading:
            return "🔄 Model is loading, please wait a few minutes and try again..."
        else:
            return "❌ Model failed to load. Please check the Space logs or try restarting."
    try:
        # Format the input with the fine-tuning prompt template
        formatted_prompt = f"### Human: {prompt.strip()}\n### Assistant:"
        # Tokenize the input
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(model.device)
        # Generate the response (early_stopping is omitted: it only applies
        # to beam search and has no effect with sampling)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )
        # Decode the full output
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the newly generated assistant portion
        if "### Assistant:" in response:
            response = response.split("### Assistant:")[-1].strip()
        elif len(response) > len(formatted_prompt):
            response = response[len(formatted_prompt):].strip()
        return response if response else "❌ No response generated. Please try again with a different prompt."
    except Exception as generation_error:
        return f"❌ Generation Error: {str(generation_error)}"
def chat_interface(message, history, max_tokens, temperature, top_p):
    """Gradio chat callback: append a (message, response) pair to the history."""
    if not message.strip():
        return history, ""
    try:
        response = generate_response(message, max_tokens, temperature, top_p)
        history.append((message, response))
        return history, ""
    except Exception as chat_error:
        error_msg = f"❌ Chat Error: {str(chat_error)}"
        history.append((message, error_msg))
        return history, ""
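
# Note: the (user, bot) tuple history matches gr.Chatbot's default "tuples"
# format. Newer Gradio releases prefer type="messages" with role/content
# dicts; migrating would mean changing both this callback and the
# gr.Chatbot(...) constructor below.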
# Build the Gradio app
with gr.Blocks(
    title="Robot Task Planning - Llama 3.1 8B",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    """
) as demo:
    gr.Markdown("""
    # 🤖 Llama 3.1 8B - Robot Task Planning

    This is a fine-tuned version of Meta's Llama 3.1 8B model, specialized for **robot task planning** using the QLoRA technique.

    **Capabilities**: Convert natural-language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.

    **Model**: [YongdongWang/llama-3.1-8b-dart-qlora](https://huggingface.co/YongdongWang/llama-3.1-8b-dart-qlora)

    ⚠️ **Note**: Model loading may take 3-5 minutes on first startup. Please be patient.
    """)
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Task Planning Results",
                height=500,
                show_label=True,
                container=True,
                bubble_full_width=False,
                show_copy_button=True
            )
            msg = gr.Textbox(
                label="Robot Command",
                placeholder="Enter robot task command (e.g., 'Deploy Excavator 1 to Soil Area 1')...",
                lines=2,
                max_lines=5,
                show_label=True,
                container=True
            )
            with gr.Row():
                send_btn = gr.Button("🚀 Generate Tasks", variant="primary", size="sm")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Generation Settings")
            max_tokens = gr.Slider(
                minimum=50,
                maximum=500,
                value=200,
                step=10,
                label="Max Tokens",
                info="Maximum number of tokens to generate"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness (lower = more focused)"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold"
            )
            gr.Markdown("""
            ### 📊 Model Status
            The model will load automatically on first use.
            Loading time: ~3-5 minutes
            """)
    # Example prompts
    gr.Examples(
        examples=[
            'Deploy Excavator 1 to Soil Area 1 for excavation.',
            'Send Dump Truck 1 to collect material from Excavator 1, then unload at storage area.',
            'Move all robots to avoid Puddle 1 after inspection.',
            'Deploy multiple excavators to different soil areas simultaneously.',
            'Coordinate dump trucks to transport materials from excavation site to storage.',
            'Send robot to inspect rock area, then avoid with all other robots if dangerous.',
            'Return all robots to start position after completing tasks.',
            'Create a sequence: excavate, load, transport, unload, repeat.'
        ],
        inputs=msg,
        label="💡 Example Robot Commands"
    )
    # Event handlers
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg]
    )
    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg]
    )
    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg]
    )
if __name__ == "__main__":
    # share=True is not supported on Hugging Face Spaces (Gradio ignores it
    # with a warning), so it is omitted here.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )