import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import warnings
import os

warnings.filterwarnings("ignore")
# Model configuration
MODEL_NAME = "meta-llama/Llama-3.1-8B"
LORA_MODEL = "YongdongWang/llama-3.1-8b-dart-qlora"
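# Note: meta-llama/Llama-3.1-8B is a gated repository. On a Space you
# typically need to accept the license on the Hub and expose an access token
# (e.g. an HF_TOKEN secret read via os.environ.get("HF_TOKEN")) to the
# from_pretrained() calls below; a missing token is a common cause of the
# "Runtime error" state this Space shows.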
def load_model():
    """Load the model and tokenizer."""
    print("🔄 Loading model...")
    try:
        # 4-bit quantization configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
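        # Rough memory arithmetic: 8B parameters at 4 bits ≈ 4 GB of weights
        # (plus quantization constants and activations), versus ~16 GB in
        # float16 -- this is what lets an 8B model fit on a single 16 GB
        # Space GPU. Double quantization shaves roughly another 0.4 GB by
        # quantizing the quantization constants themselves.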
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            use_fast=False,
            trust_remote_code=True
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
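        # Llama tokenizers ship without a dedicated pad token, so padding
        # falls back to EOS here; generate() is passed the same id below so
        # padding and end-of-sequence handling stay consistent.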
        # Load the base model
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )
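        # device_map="auto" lets accelerate place layers across the available
        # GPU(s) and, if needed, CPU. Note that bitsandbytes 4-bit inference
        # requires a CUDA GPU, so on CPU-only Space hardware this call is a
        # likely source of the runtime error.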
        # Load the LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            LORA_MODEL,
            torch_dtype=torch.float16
        )
        model.eval()
        print("✅ Model loaded successfully!")
        return model, tokenizer
    except Exception as load_error:
        print(f"❌ Model loading failed: {load_error}")
        return None, None
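# The PeftModel wrapper keeps the quantized base weights frozen and only adds
# the small LoRA matrices (typically tens of MB), so the adapter itself loads
# quickly. Merging the adapter into a 4-bit base is awkward (the weights would
# need dequantizing first), so it is kept separate at inference time.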
# Global model storage
model = None
tokenizer = None
model_loading = False
def initialize_model():
    """Initialize the model (lazy loading)."""
    global model, tokenizer, model_loading
    if model is not None and tokenizer is not None:
        return True
    if model_loading:
        return False
    model_loading = True
    try:
        model, tokenizer = load_model()
        return model is not None and tokenizer is not None
    finally:
        model_loading = False
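# Caveat: this lazy-load guard is not thread-safe. Gradio can run event
# handlers concurrently, so two first requests may both pass the
# model_loading check and trigger duplicate loads; wrapping the load in a
# threading.Lock would close that race.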
def generate_response(prompt, max_tokens=200, temperature=0.7, top_p=0.9):
    """Generate a response."""
    if not initialize_model():
        if model_loading:
            return "🔄 Model is loading, please wait a few minutes and try again..."
        else:
            return "❌ Model failed to load. Please check the Space logs or try restarting."
    try:
        # Format the input
        formatted_prompt = f"### Human: {prompt.strip()}\n### Assistant:"
        # Encode the input
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(model.device)
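        # The "### Human: ... ### Assistant:" template must match the format
        # the LoRA adapter was fine-tuned on; with a different template the
        # model tends to ignore the task structure it was trained to emit.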
        # Generate the response
        # (early_stopping removed: it only applies to beam search and is
        # ignored when sampling, which triggers a warning)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )
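        # Sampling knobs: temperature rescales the logits (lower = more
        # deterministic), top_p keeps only the smallest set of tokens whose
        # cumulative probability exceeds the threshold, and
        # repetition_penalty / no_repeat_ngram_size suppress the loops that
        # fine-tuned models are prone to when emitting repetitive task lists.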
        # Decode the output
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Extract the generated portion
        if "### Assistant:" in response:
            response = response.split("### Assistant:")[-1].strip()
        elif len(response) > len(formatted_prompt):
            response = response[len(formatted_prompt):].strip()
        return response if response else "❌ No response generated. Please try again with a different prompt."
    except Exception as generation_error:
        return f"❌ Generation Error: {str(generation_error)}"
def chat_interface(message, history, max_tokens, temperature, top_p):
    """Chat interface handler."""
    if not message.strip():
        return history, ""
    try:
        response = generate_response(message, max_tokens, temperature, top_p)
        history.append((message, response))
        return history, ""
    except Exception as chat_error:
        error_msg = f"❌ Chat Error: {str(chat_error)}"
        history.append((message, error_msg))
        return history, ""
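# Note: history is kept in Gradio's legacy tuple format (one
# (user, assistant) pair per turn), which matches the default gr.Chatbot
# below. Newer Gradio releases prefer type="messages" with role/content
# dicts; the tuple format still works but is deprecated in Gradio 5.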
# Build the Gradio app
with gr.Blocks(
    title="Robot Task Planning - Llama 3.1 8B",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    """
) as demo:
    gr.Markdown("""
    # 🤖 Llama 3.1 8B - Robot Task Planning

    This is a fine-tuned version of Meta's Llama 3.1 8B model, specialized for **robot task planning** via QLoRA.

    **Capabilities**: Convert natural-language robot commands into structured task sequences for excavators, dump trucks, and other construction robots.

    **Model**: [YongdongWang/llama-3.1-8b-dart-qlora](https://huggingface.co/YongdongWang/llama-3.1-8b-dart-qlora)

    ⚠️ **Note**: Model loading may take 3-5 minutes on first startup. Please be patient.
    """)
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Task Planning Results",
                height=500,
                show_label=True,
                container=True,
                bubble_full_width=False,  # deprecated in recent Gradio; drop if your version rejects it
                show_copy_button=True
            )
            msg = gr.Textbox(
                label="Robot Command",
                placeholder="Enter robot task command (e.g., 'Deploy Excavator 1 to Soil Area 1')...",
                lines=2,
                max_lines=5,
                show_label=True,
                container=True
            )
            with gr.Row():
                send_btn = gr.Button("🚀 Generate Tasks", variant="primary", size="sm")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Generation Settings")
            max_tokens = gr.Slider(
                minimum=50,
                maximum=500,
                value=200,
                step=10,
                label="Max Tokens",
                info="Maximum number of tokens to generate"
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Controls randomness (lower = more focused)"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold"
            )
            gr.Markdown("""
            ### 📊 Model Status

            The model will load automatically on first use.
            Loading time: ~3-5 minutes
            """)
    # Example prompts
    gr.Examples(
        examples=[
            'Deploy Excavator 1 to Soil Area 1 for excavation.',
            'Send Dump Truck 1 to collect material from Excavator 1, then unload at storage area.',
            'Move all robots to avoid Puddle 1 after inspection.',
            'Deploy multiple excavators to different soil areas simultaneously.',
            'Coordinate dump trucks to transport materials from excavation site to storage.',
            'Send robot to inspect rock area, then avoid with all other robots if dangerous.',
            'Return all robots to start position after completing tasks.',
            'Create a sequence: excavate, load, transport, unload, repeat.'
        ],
        inputs=msg,
        label="💡 Example Robot Commands"
    )
    # Event handlers
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg]
    )
    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, temperature, top_p],
        outputs=[chatbot, msg]
    )
    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg]
    )
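# Tip: on a single-GPU Space it is usually worth calling demo.queue() before
# launch, so concurrent generate requests are serialized instead of
# contending for the GPU.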
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True  # share=True is ignored on Spaces, so it is dropped here
    )