import gradio as gr
import os
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from threading import Thread
import torch
import time

# Set environment variables
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Apollo system prompt
SYSTEM_PROMPT = "You are Apollo, a multilingual medical model. You communicate with people and assist them."

LICENSE = """
@misc{wang2024apollo,
      title={Apollo: Lightweight Multilingual Medical LLMs towards Democratizing Medical AI to 6B People},
      author={Xidong Wang and Nuo Chen and Junyin Chen and Yan Hu and Yidong Wang and Xiangbo Wu and Anningzhe Gao and Xiang Wan and Haizhou Li and Benyou Wang},
      year={2024},
      eprint={2403.03640},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

@misc{zheng2024efficientlydemocratizingmedicalllms,
      title={Efficiently Democratizing Medical LLMs for 50 Languages via a Mixture of Language Family Experts},
      author={Guorui Zheng and Xidong Wang and Juhao Liang and Nuo Chen and Yuping Zheng and Benyou Wang},
      year={2024},
      eprint={2410.10626},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2410.10626},
}
""" # Apollo model options APOLLO_MODELS = { "Apollo": [ "FreedomIntelligence/Apollo-7B", "FreedomIntelligence/Apollo-6B", "FreedomIntelligence/Apollo-2B", "FreedomIntelligence/Apollo-0.5B", ], "Apollo2": [ "FreedomIntelligence/Apollo2-7B", "FreedomIntelligence/Apollo2-3.8B", "FreedomIntelligence/Apollo2-2B", ], "Apollo-MoE": [ "FreedomIntelligence/Apollo-MoE-7B", "FreedomIntelligence/Apollo-MoE-1.5B", "FreedomIntelligence/Apollo-MoE-0.5B", ] } # CSS styles css = """ h1 { text-align: center; display: block; } .gradio-container { max-width: 1200px; margin: auto; } """ # Global variables to store currently loaded model and tokenizer current_model = None current_tokenizer = None current_model_path = None @spaces.GPU(duration=120) def load_model(model_path, progress=gr.Progress()): """Load the selected model and tokenizer""" global current_model, current_tokenizer, current_model_path # If the same model is already loaded, don't reload it if current_model_path == model_path and current_model is not None: return "Model already loaded, no need to reload." # Clean up previously loaded model (if any) if current_model is not None: del current_model del current_tokenizer torch.cuda.empty_cache() progress(0.1, desc=f"Starting to load model {model_path}...") try: progress(0.3, desc="Loading tokenizer...") config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) if 'MoE' in model_path: from configuration_upcycling_qwen2_moe import UpcyclingQwen2MoeConfig config = UpcyclingQwen2MoeConfig.from_pretrained(model_path, trust_remote_code=True) # config_moe.auto_map["AutoConfig"] = "./configuration_upcycling_qwen2_moe.UpcyclingQwen2MoeConfig" # config_moe.auto_map["AutoModelForCausalLM"] = "./modeling_upcycling_qwen2_moe.UpcyclingQwen2MoeForCausalLM" current_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False,trust_remote_code=True) progress(0.5, desc="Loading model...") if 'MoE' in model_path: from modeling_upcycling_qwen2_moe import UpcyclingQwen2MoeForCausalLM current_model = UpcyclingQwen2MoeForCausalLM.from_pretrained( model_path, device_map="auto", torch_dtype=torch.float16, config=config, trust_remote_code=True ) else: current_model = AutoModelForCausalLM.from_pretrained( model_path, device_map="auto", torch_dtype=torch.float16, config=config, trust_remote_code=True ) current_model_path = model_path progress(1.0, desc="Model loading complete!") return f"Model {model_path} successfully loaded." 
    except Exception as e:
        progress(1.0, desc="Model loading failed!")
        return f"Model loading failed: {str(e)}"


@spaces.GPU(duration=120)
def generate_response_non_streaming(instruction, model_name, temperature=0.7, max_tokens=1024):
    """Generate a response from the Apollo model (non-streaming)."""
    global current_model, current_tokenizer, current_model_path
    print("instruction:", instruction)

    # If the model is not yet loaded, load it first
    if current_model_path != model_name or current_model is None:
        load_message = load_model(model_name)
        if "failed" in load_message.lower():
            return load_message

    try:
        # Use a simple prompt format directly instead of the model's chat template
        prompt = f"User:{instruction}\nAssistant:"
        print("prompt:", prompt)
        chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)

        # Generate the response
        output = current_model.generate(
            input_ids=chat_input,
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=(temperature > 0),
            eos_token_id=current_tokenizer.eos_token_id,  # use <|endoftext|> as the stop token
        )

        # Decode and return the generated text, excluding the prompt tokens
        generated_text = current_tokenizer.decode(output[0][len(chat_input[0]):], skip_special_tokens=True)
        print("generated_text:", generated_text)
        return generated_text
    except Exception as e:
        return f"Error generating response: {str(e)}"


def update_chat_with_response(chatbot, instruction, model_name, temperature, max_tokens):
    """Update the chatbot with a non-streaming response."""
    global current_model, current_tokenizer, current_model_path

    # If the model is not yet loaded, load it first
    if current_model_path != model_name or current_model is None:
        load_result = load_model(model_name)
        if "failed" in load_result.lower():
            new_chat = list(chatbot)
            new_chat[-1] = (instruction, load_result)
            return new_chat

    # Generate a response using the non-streaming function
    response = generate_response_non_streaming(instruction, model_name, temperature, max_tokens)

    # Create a copy of the current chatbot and add the response
    new_chat = list(chatbot)
    new_chat[-1] = (instruction, response)
    return new_chat


def on_model_series_change(model_series):
    """Update the available model list based on the selected model series."""
    if model_series in APOLLO_MODELS:
        return gr.update(choices=APOLLO_MODELS[model_series], value=APOLLO_MODELS[model_series][0])
    return gr.update(choices=[], value=None)


def process_message(message, chat_history, model_series_value, model_name_value, temperature_value, max_tokens_value):
    """Process the user message and generate a response."""
    if message.strip() == "":
        return "", chat_history

    # Print the user's submitted message for debugging
print("instruction:", message) # Add user message to chat history chat_history = list(chat_history) chat_history.append((message, None)) # 自动加载模型(如果需要) global current_model, current_tokenizer, current_model_path if current_model_path != model_name_value or current_model is None: try: load_result = load_model(model_name_value) if "failed" in load_result.lower(): chat_history[-1] = (message, f"模型加载失败: {load_result}") return "", chat_history except Exception as e: chat_history[-1] = (message, f"模型加载出错: {str(e)}") return "", chat_history # Generate response try: response = generate_response_non_streaming(message, model_name_value, temperature_value, max_tokens_value) # Add response to chat history chat_history[-1] = (message, response) except Exception as e: chat_history[-1] = (message, f"生成响应时出错: {str(e)}") return "", chat_history # Create Gradio interface with gr.Blocks(css=css) as demo: # Title and description favicon = "🩺" gr.Markdown( f"""# {favicon} Apollo Playground This is a demo of the multilingual medical model series **[Apollo](https://github.com/FreedomIntelligence/Apollo)** made by **[FreedomIntelligence](https://huggingface.co/FreedomIntelligence)**. [Apollo1](https://arxiv.org/abs/2403.03640) supports 6 languages. [Apollo2](https://arxiv.org/abs/2410.10626) and [Apollo-MoE](https://arxiv.org/abs/2410.10626) supports 50 languages. """ ) with gr.Row(): with gr.Column(scale=1): # Model selection controls model_series = gr.Dropdown( choices=list(APOLLO_MODELS.keys()), value="Apollo", label="Select Model Series", info="First choose Apollo, Apollo2 or Apollo-MoE" ) model_name = gr.Dropdown( choices=APOLLO_MODELS["Apollo"], value=APOLLO_MODELS["Apollo"][0], label="Select Model Size", info="Select the specific model size based on the chosen model series" ) # Parameter settings with gr.Accordion("Generation Parameters", open=False): temperature = gr.Slider( minimum=0.0, maximum=1.0, value=0.7, step=0.05, label="Temperature" ) max_tokens = gr.Slider( minimum=128, maximum=2048, value=1024, step=32, label="Maximum Tokens" ) # 移除Load Model按钮和状态显示 # load_button = gr.Button("Load Model") # model_status = gr.Textbox(label="Model Status", value="No model loaded yet") with gr.Column(scale=2): # Chat interface chatbot = gr.Chatbot(label="Conversation", height=500, value=[]) # Initialize with empty list user_input = gr.Textbox( label="Input Medical Question", placeholder="Example: What are the symptoms of hypertension? 
                lines=3,
            )
            submit_button = gr.Button("Submit")
            clear_button = gr.Button("Clear Chat")

    # Event handling
    # Update the model list when the model series changes
    model_series.change(
        fn=on_model_series_change,
        inputs=model_series,
        outputs=model_name,
    )

    # Bind message submission (Enter key and Submit button)
    submit_event = user_input.submit(
        fn=process_message,
        inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
        outputs=[user_input, chatbot],
    )
    submit_button.click(
        fn=process_message,
        inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
        outputs=[user_input, chatbot],
    )

    # Clear the chat
    clear_button.click(
        fn=lambda: [],
        outputs=chatbot,
    )

    examples = [
        ["Últimamente tengo la tensión un poco alta, ¿cómo debo adaptar mis hábitos?"],
        ["What are the common side effects of metformin?"],
        ["中医和西医在治疗高血压方面有什么不同的观点?"],
        ["मेरा सिर दर्द कर रहा है, मुझे क्या करना चाहिए? "],
        ["Comment savoir si je suis diabétique ?"],
        ["ما الدواء الذي يمكنني تناوله إذا لم أستطع النوم ليلاً؟"],
        ["针对一名28岁女性患者,她左小腿挫伤12小时,伤口有分泌物,骨折端外露,小腿成角畸形,描述她的最佳处理方法。"],
    ]

    gr.Examples(
        examples=examples,
        inputs=user_input,
    )

    gr.HTML(LICENSE)

if __name__ == "__main__":
    demo.launch()
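
# The demo above returns each answer in one piece. As a possible extension (not part of the
# original app), the already-imported Thread could be combined with transformers'
# TextIteratorStreamer to stream tokens as they are generated. The sketch below is only a
# sketch: the name generate_response_streaming is illustrative, it is not wired into the
# Gradio UI, and it assumes load_model() has already populated the current_model and
# current_tokenizer globals.
from transformers import TextIteratorStreamer


def generate_response_streaming(instruction, temperature=0.7, max_tokens=1024):
    """Yield the Apollo response incrementally instead of waiting for the full generation (sketch)."""
    prompt = f"User:{instruction}\nAssistant:"
    chat_input = current_tokenizer.encode(prompt, return_tensors="pt").to(current_model.device)
    # The streamer skips the prompt tokens and decodes new tokens as they arrive
    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=chat_input,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=(temperature > 0),
        eos_token_id=current_tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Run generation in a background thread and consume the streamer on the calling thread
    Thread(target=current_model.generate, kwargs=generation_kwargs).start()
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text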