Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Legal RAG Chatbot for Hugging Face Spaces | |
| 直接使用原有的chatbot.py,只添加Gradio界面 | |
| """ | |
| import os | |
| import gradio as gr | |
| from chatbot import LegalChatbot | |
# Configuration, sourced from environment variables with local fallbacks.
MILVUS_DB_PATH = os.getenv("MILVUS_DB_PATH", "./milvus_legal_codes.db")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "legal_codes_collection")
# No default: initialize_chatbot() refuses to start without a key.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Empty string means "no custom endpoint" (it is passed on as None).
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL", "")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
def mask_secret(secret, visible=4):
    """Return a fixed-width redacted form of *secret* that is safe to log.

    Only the last ``visible`` characters are revealed, and only when the
    secret is long enough (more than ``3 * visible`` characters) that doing
    so gives little away. The mask has a fixed width so the secret's length
    is not leaked either.

    Args:
        secret: The secret string (may be ``None`` or empty).
        visible: How many trailing characters to reveal; ``<= 0`` reveals none.

    Returns:
        ``"********"`` optionally followed by the last ``visible`` characters.
    """
    mask = "*" * 8
    # visible <= 0 must be rejected explicitly: secret[-0:] is the WHOLE string.
    if not secret or visible <= 0 or len(secret) <= 3 * visible:
        return mask
    return mask + secret[-visible:]


# Initialize the chatbot from the module-level configuration.
def initialize_chatbot():
    """Create a ``LegalChatbot`` configured from the module-level settings.

    Returns:
        A new ``LegalChatbot`` instance.

    Raises:
        ValueError: if ``OPENAI_API_KEY`` is unset or empty.
    """
    if not OPENAI_API_KEY:
        raise ValueError("❌ OPENAI_API_KEY environment variable is required!")
    # Never log a real prefix of the key; application logs are not a secret store.
    print(f"🔑 Using API key: {mask_secret(OPENAI_API_KEY)}")
    print(f"🌐 Base URL: {OPENAI_BASE_URL or 'Default OpenAI'}")
    return LegalChatbot(
        milvus_db_path=MILVUS_DB_PATH,
        collection_name=COLLECTION_NAME,
        openai_api_key=OPENAI_API_KEY,
        # An empty OPENAI_BASE_URL is passed as None (no custom endpoint).
        openai_base_url=OPENAI_BASE_URL or None,
        model_name=MODEL_NAME,
    )
# Module-wide chatbot instance shared by every request; None if setup failed.
try:
    # Only the constructor is guarded, so that an error while building the
    # status text below cannot discard an already-working chatbot.
    chatbot = initialize_chatbot()
except Exception as e:
    # Top-level boundary: keep the UI running and surface the problem in its
    # status line instead of crashing the app at import time.
    print(f"❌ Failed to initialize chatbot: {e}")
    chatbot = None
    chatbot_status = f"❌ **Status**: Configuration error - {e}"
else:
    print("✅ Chatbot initialized successfully!")
    # Fall back to the configured name if the instance lacks the attribute.
    _status_collection = getattr(chatbot, "collection_name", COLLECTION_NAME)
    chatbot_status = (
        f"✅ **Status**: Connected to database with {_status_collection} collection"
    )
def normalize_history(history):
    """Convert Gradio chat history into ``{"role", "content"}`` message dicts.

    Accepts both history formats Gradio's ChatInterface can pass:

    * "tuples": a sequence of ``(user_message, assistant_message)`` pairs,
      where either side may be ``None`` or empty;
    * "messages": a sequence of dicts with ``"role"`` and ``"content"`` keys
      (extra keys such as ``"metadata"`` are ignored).

    Entries with empty/``None`` content are skipped, and only the ``user`` and
    ``assistant`` roles are kept. ``None`` history yields an empty list.
    """
    messages = []
    for entry in history or ():
        if isinstance(entry, dict):
            role = entry.get("role")
            content = entry.get("content")
            if role in ("user", "assistant") and content:
                messages.append({"role": role, "content": content})
            continue
        # Tuples/lists format: exactly one (user, assistant) pair per entry.
        user_msg, assistant_msg = entry
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    return messages


def respond(
    message,
    history: list,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio ChatInterface callback that streams the chatbot's reply.

    Rebuilds the shared chatbot's conversation from Gradio's ``history`` on
    every call, then streams the answer from
    ``chatbot.process_message_stream``.

    Args:
        message: The new user message.
        history: Prior turns in either Gradio "tuples" format
            (``[(user, assistant), ...]``) or "messages" format
            (``[{"role": ..., "content": ...}, ...]``).
        system_message: Text of the "System Message" box; when non-blank it
            replaces the content of the first conversation-history entry.
        max_tokens: Value of the "Max new tokens" slider (currently unused).
        temperature: Value of the "Temperature" slider (currently unused).
        top_p: Value of the "Top-p" slider (currently unused).

    Yields:
        The cumulative response text so far (not deltas), or a single error
        message if initialization or processing failed.
    """
    if chatbot is None:
        yield "❌ Chatbot not initialized. Please check the configuration."
        return

    # NOTE(review): max_tokens, temperature and top_p are never forwarded to
    # the chatbot, so the UI sliders have no effect. Wire them through once
    # LegalChatbot's API for generation parameters is confirmed.
    try:
        # The chatbot instance is module-wide state, so its history is rebuilt
        # from Gradio's copy on every call. NOTE(review): concurrent requests
        # would race on this shared state if the queue allows concurrency > 1.
        chatbot.reset_conversation()
        if system_message and system_message.strip():
            # Assumes reset_conversation() leaves the system prompt at index 0
            # of conversation_history -- TODO confirm against LegalChatbot.
            chatbot.conversation_history[0]["content"] = system_message
        chatbot.conversation_history.extend(normalize_history(history))

        # Stream via the chatbot's own streaming method, yielding the text
        # accumulated so far after each non-empty chunk.
        response = ""
        for chunk in chatbot.process_message_stream(message):
            if chunk:
                response += chunk
                yield response
    except Exception as e:
        # UI boundary: log the failure and show it in the chat instead of crashing.
        print(f"❌ Error in respond: {e}")
        yield f"抱歉,处理您的消息时出现错误:{e}"
# Markdown shown under the chat title: feature overview plus example prompts.
base_description = """🤖 **AI法律助手** - 结合向量数据库搜索和大语言模型的智能法律咨询系统
🔍 **核心功能:**
- 智能查询分析 - 自动判断是否需要搜索法律数据库
- 向量相似度搜索 - 基于Milvus的高效法律文档检索
- RAG增强生成 - 结合搜索结果提供准确回答
- 实时流式回复 - 支持打字机效果的实时响应
💡 **试试这些问题:**
• "What are the fall protection requirements in Ontario construction?"
• "Tell me about employer duties under Canada Labour Code"
• "Search for information about workplace safety regulations"
• "What are my rights under the Charter of Rights and Freedoms?"
"""
# Append the initialization status so configuration problems are visible in the UI.
full_description = "\n\n".join((base_description, chatbot_status))
# Default text of the "System Message" box (passed to respond() as system_message).
DEFAULT_SYSTEM_MESSAGE = (
    "You are a helpful legal assistant with expertise in Canadian law. "
    "You have access to a legal database and should provide accurate, "
    "well-sourced legal information. Always cite specific legal sources when "
    "possible. Remember to include appropriate disclaimers that this is for "
    "informational purposes only and not legal advice."
)

# Extra controls shown with the chat box; their values reach respond() after
# (message, history), in this order: system_message, max_tokens, temperature, top_p.
_chat_controls = [
    gr.Textbox(
        value=DEFAULT_SYSTEM_MESSAGE,
        label="System Message",
        lines=3,
        max_lines=5,
    ),
    gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(
        minimum=0.1,
        maximum=1.0,
        value=0.95,
        step=0.05,
        label="Top-p (nucleus sampling)",
    ),
]

# Chat UI wired to the streaming respond() generator.
demo = gr.ChatInterface(
    fn=respond,
    title="⚖️ Legal RAG Assistant",
    description=full_description,
    additional_inputs=_chat_controls,
    theme=gr.themes.Soft(),
    analytics_enabled=False,
)
if __name__ == "__main__":
    # Serve on all interfaces at port 7860, without a public share link,
    # and with errors shown in the browser.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_kwargs)