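"""Gradio app for the MedLLaMA2 Medical Chatbot on Hugging Face Spaces.

Loads a lightweight medical language model (falling back through several
candidates), generates answers with a safety disclaimer, and serves them
through a gr.ChatInterface.
"""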
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
import logging
import gc
import warnings
import os
import time
from huggingface_hub import login
from config import MODEL_CONFIGS, DEFAULT_MODEL, MODEL_SETTINGS, GENERATION_DEFAULTS, MEDICAL_SYSTEM_PROMPT, UI_CONFIG
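
# NOTE: config.py is not included in this file. Based on how it is used below, it is
# expected to provide the following names (the example values are hypothetical,
# for illustration only):
#
#   MODEL_CONFIGS         dict of model keys -> {"name": ..., "description": ...},
#                         e.g. {"flan_t5_small": {"name": "google/flan-t5-small",
#                                                 "description": "Lightweight seq2seq model"}}
#   DEFAULT_MODEL         a key into MODEL_CONFIGS, e.g. "flan_t5_small"
#   MODEL_SETTINGS        {"trust_remote_code": ..., "low_cpu_mem_usage": ...,
#                          "use_quantization": ..., "device_map": ...}
#   GENERATION_DEFAULTS   {"max_new_tokens": ..., "temperature": ..., "top_p": ...,
#                          "do_sample": ..., "repetition_penalty": ...,
#                          "no_repeat_ngram_size": ...}
#   MEDICAL_SYSTEM_PROMPT system prompt string used for causal models
#   UI_CONFIG             {"title": ..., "description": ..., "examples": [...],
#                          "max_tokens_range": (lo, hi), "temperature_range": (lo, hi),
#                          "top_p_range": (lo, hi)}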

# Login with the secret token
login(token=os.getenv("HF_TOKEN"))

# Suppress warnings
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Global variables for model and tokenizer
model = None
tokenizer = None
current_model_name = None


def load_model(model_key=None):
    """Load the specified medical model with optimizations for Hugging Face Spaces."""
    global model, tokenizer, current_model_name

    if model_key is None:
        model_key = DEFAULT_MODEL

    # Try to load models in order of preference, prioritizing lightweight models
    model_keys_to_try = [model_key, "flan_t5_small", "dialogpt_medium", "meditron"]

    for key in model_keys_to_try:
        if key not in MODEL_CONFIGS:
            continue
        try:
            model_config = MODEL_CONFIGS[key]
            model_name = model_config["name"]
            print(f"Attempting to load model: {model_name} ({model_config['description']})")

            # Load tokenizer first
            print("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=MODEL_SETTINGS["trust_remote_code"],
                padding_side="left"
            )

            # Add a pad token if it doesn't exist
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Common loading arguments
            model_kwargs = {
                "trust_remote_code": MODEL_SETTINGS["trust_remote_code"],
                "low_cpu_mem_usage": MODEL_SETTINGS["low_cpu_mem_usage"]
            }

            if MODEL_SETTINGS["use_quantization"] and torch.cuda.is_available() and key in ["medllama2", "meditron", "clinical_camel"]:
                # Use 4-bit quantization for memory efficiency, only on GPU and only for the larger models
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )
                model_kwargs["quantization_config"] = quantization_config
                model_kwargs["torch_dtype"] = torch.float16
                model_kwargs["device_map"] = MODEL_SETTINGS["device_map"]
            else:
                # For CPU or smaller models, use optimized settings
                if torch.cuda.is_available():
                    model_kwargs["torch_dtype"] = torch.float16
                    model_kwargs["device_map"] = "auto"
                else:
                    model_kwargs["torch_dtype"] = torch.float32  # use float32 on CPU
                    model_kwargs["device_map"] = None  # let the model stay on CPU

            print("Loading model...")
            # Use the appropriate model class based on model type
            if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
                model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
            else:
                model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

            current_model_name = model_name
            print(f"✅ Model loaded successfully: {model_name}")
            return True

        except Exception as e:
            print(f"❌ Failed to load {key}: {str(e)}")
            # Clean up on failure
            model = None
            tokenizer = None
            continue

    print("❌ All model loading attempts failed")
    return False


def generate_response(prompt, max_tokens=None, temperature=None, top_p=None):
    """Generate a response using the loaded model."""
    global model, tokenizer, current_model_name

    print(f"Starting generation for prompt: {prompt}")

    if not prompt or not prompt.strip():
        return "Please enter a question. 😊"

    if model is None or tokenizer is None:
        return "❌ Model not loaded. Please wait for initialization or try restarting the Space."

    # Use defaults if not specified
    max_tokens = max_tokens or GENERATION_DEFAULTS["max_new_tokens"]
    temperature = temperature or GENERATION_DEFAULTS["temperature"]
    top_p = top_p or GENERATION_DEFAULTS["top_p"]

    try:
        # Format the prompt based on model type
        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
            # Use a concise instruction prefix for T5
            instruction = "You are a friendly medical assistant. Answer with short, clear health info. Use emojis like 😊. For serious issues, suggest seeing a doctor."
            full_input = f"{instruction}\nQuestion: {prompt} Answer:"
        else:
            # Causal LM format
            full_input = f"{MEDICAL_SYSTEM_PROMPT}\n\nPatient/User: {prompt}\n"

        print(f"Full input: {full_input}")

        # Tokenize the input with truncation (max_length kept small for T5)
        inputs = tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        # Move inputs to the model's device
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generation parameters, capped to keep answers short and responsive
        generation_kwargs = {
            "max_new_tokens": min(max_tokens, 256),  # cap at 256 new tokens
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": GENERATION_DEFAULTS["do_sample"],
            "repetition_penalty": GENERATION_DEFAULTS["repetition_penalty"],
            "no_repeat_ngram_size": GENERATION_DEFAULTS["no_repeat_ngram_size"]
        }

        # Add pad_token_id for non-T5 (causal) models
        if not ("flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower()):
            generation_kwargs["pad_token_id"] = tokenizer.eos_token_id

        print(f"Generating with kwargs: {generation_kwargs}")

        # Generate the response
        print(f"🤖 Generating response with {current_model_name}...")
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)
        generation_time = time.time() - start_time
        print(f"⏱️ Generation completed in {generation_time:.2f} seconds")

        # Decode the response - T5 returns only the answer, causal models echo the prompt
        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
            response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        else:
            # Causal models return prompt + answer, so decode only the newly generated tokens
            input_length = inputs["input_ids"].shape[1]
            response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()

        print(f"Generated response: {response}")

        # Fall back to a canned reply if the model produced little or nothing
        if not response or len(response.strip()) < 10:
            response = "Sorry, I couldn't process that. Try again or see a doctor. 😊"
| print(f"β Generated response length: {len(response)} characters") | |
| print(f"π Response preview: {response[:150]}{'...' if len(response) > 150 else ''}") | |
| # Clean up memory | |
| del inputs, outputs | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| gc.collect() # Force garbage collection | |
| print(f"π Generated response: {response}") | |
| return response | |
| except Exception as e: | |
| error_msg = f"Error generating response: {str(e)}" | |
| print(error_msg) | |
| return f"β οΈ I encountered a technical issue while processing your request. Please try again or rephrase your question. If the problem persists, consider consulting a healthcare professional directly." | |


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Main response function for the Gradio ChatInterface."""
    if not message or not message.strip():
        return "Please enter a medical question or concern."

    # Disclaimer appended to responses that don't already include one
    disclaimer = "\n\n⚠️ **Medical Disclaimer**: This AI provides general health information only. Always consult healthcare professionals for medical advice, diagnosis, or treatment."

    try:
        # Generate the response
        response = generate_response(
            message.strip(),
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p)
        )

        # Append the disclaimer unless the response already covers it
        if "disclaimer" not in response.lower() and "consult" not in response.lower():
            response += disclaimer
        return response

    except Exception as e:
        error_msg = f"System error: {str(e)}"
        print(error_msg)
        return f"⚠️ System temporarily unavailable. Please try again later or consult a healthcare professional directly.{disclaimer}"


def get_model_info():
    """Get information about the currently loaded model."""
    if current_model_name:
        return f"Currently using: {current_model_name}"
    return "No model loaded"


# Load the model on startup
print("🏥 Initializing MedLLaMA2 Medical Chatbot...")
print("🔄 Loading medical language model...")
model_loaded = load_model()

if model_loaded:
    print(f"✅ Ready! {get_model_info()}")
else:
    print("⚠️ WARNING: Model failed to load. The app will run but responses may be limited.")

# Create the Gradio interface with configuration
demo = gr.ChatInterface(
    respond,
    title=UI_CONFIG["title"],
    description=UI_CONFIG["description"],
    additional_inputs=[
        gr.Textbox(
            value=MEDICAL_SYSTEM_PROMPT,
            label="System Instructions",
            lines=4,
            interactive=False  # read-only to prevent tampering
        ),
        gr.Slider(
            minimum=UI_CONFIG["max_tokens_range"][0],
            maximum=UI_CONFIG["max_tokens_range"][1],
            value=GENERATION_DEFAULTS["max_new_tokens"],
            step=10,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=UI_CONFIG["temperature_range"][0],
            maximum=UI_CONFIG["temperature_range"][1],
            value=GENERATION_DEFAULTS["temperature"],
            step=0.1,
            label="Temperature (creativity)"
        ),
        gr.Slider(
            minimum=UI_CONFIG["top_p_range"][0],
            maximum=UI_CONFIG["top_p_range"][1],
            value=GENERATION_DEFAULTS["top_p"],
            step=0.05,
            label="Top-p (focus)",
        ),
    ],
    examples=[[example] for example in UI_CONFIG["examples"]],
    cache_examples=False,
    theme=gr.themes.Soft(),
    css=".gradio-container {max-width: 900px; margin: auto;}"
)

# Add model info below the chat interface
with demo:
    gr.HTML(f"<p style='text-align: center; color: #666; font-size: 0.9em;'>Model Status: {get_model_info()}</p>")

if __name__ == "__main__":
    # For Hugging Face Spaces deployment
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
        debug=True
    )
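
# To run outside of Spaces (assuming config.py and the dependencies - gradio, torch,
# transformers, bitsandbytes, huggingface_hub - are in place, and that this file is
# saved as app.py, as Spaces conventionally expects):
#   python app.py
# then open http://localhost:7860 in a browser.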