# medical_model / app.py
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
import logging
import gc
import warnings
import os
from huggingface_hub import login
from config import MODEL_CONFIGS, DEFAULT_MODEL, MODEL_SETTINGS, GENERATION_DEFAULTS, MEDICAL_SYSTEM_PROMPT, UI_CONFIG
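# The model registry, default model key, generation defaults, medical system prompt, and UI settings all live in the local config.py module.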
# Log in with the Space's HF_TOKEN secret; skip if it isn't set (e.g. when running locally)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# Suppress warnings
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)
# Global variables for model and tokenizer
model = None
tokenizer = None
current_model_name = None
def load_model(model_key=None):
"""Load the specified medical model with optimizations for Hugging Face Spaces"""
global model, tokenizer, current_model_name
if model_key is None:
model_key = DEFAULT_MODEL
# Try to load models in order of preference - prioritize lightweight models
    # Deduplicate while keeping order so a failing default key isn't retried
    model_keys_to_try = list(dict.fromkeys([model_key, "flan_t5_small", "dialogpt_medium", "meditron"]))
for key in model_keys_to_try:
if key not in MODEL_CONFIGS:
continue
try:
model_config = MODEL_CONFIGS[key]
model_name = model_config["name"]
print(f"Attempting to load model: {model_name} ({model_config['description']})")
# Load tokenizer first
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=MODEL_SETTINGS["trust_remote_code"],
padding_side="left"
)
# Add pad token if it doesn't exist
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Configure quantization for memory efficiency (only for larger models)
model_kwargs = {
"trust_remote_code": MODEL_SETTINGS["trust_remote_code"],
"low_cpu_mem_usage": MODEL_SETTINGS["low_cpu_mem_usage"]
}
# Optimized loading for CPU performance
if MODEL_SETTINGS["use_quantization"] and torch.cuda.is_available() and key in ["medllama2", "meditron", "clinical_camel"]:
# Only use quantization on GPU for larger models
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
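                # NF4 4-bit weights with fp16 compute cut weight memory to roughly a quarter of plain fp16 loading.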
model_kwargs["quantization_config"] = quantization_config
model_kwargs["torch_dtype"] = torch.float16
model_kwargs["device_map"] = MODEL_SETTINGS["device_map"]
else:
# For CPU or smaller models, use optimized settings
if torch.cuda.is_available():
model_kwargs["torch_dtype"] = torch.float16
model_kwargs["device_map"] = "auto"
else:
# CPU-optimized settings
model_kwargs["torch_dtype"] = torch.float32 # Use float32 on CPU
model_kwargs["device_map"] = None # Let it use CPU naturally
print("Loading model...")
# Use appropriate model class based on model type
if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
current_model_name = model_name
print(f"βœ… Model loaded successfully: {model_name}")
return True
except Exception as e:
print(f"❌ Failed to load {key}: {str(e)}")
# Clean up on failure
model = None
tokenizer = None
continue
print("❌ All model loading attempts failed")
return False
def generate_response(prompt, max_tokens=None, temperature=None, top_p=None):
"""Generate response using the loaded model"""
global model, tokenizer, current_model_name
print(f"Starting generation for prompt: {prompt}")
if not prompt or not prompt.strip():
return "Please enter a question. 😊"
if model is None or tokenizer is None:
return "❌ Model not loaded. Please wait for initialization or try restarting the space."
# Use defaults if not specified
max_tokens = max_tokens or GENERATION_DEFAULTS["max_new_tokens"]
temperature = temperature or GENERATION_DEFAULTS["temperature"]
top_p = top_p or GENERATION_DEFAULTS["top_p"]
try:
# Format prompt based on model type
if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
# Use a concise instruction prefix for T5
instruction = "You are a friendly medical assistant. Answer with short, clear health info. Use emojis like 😊. For serious issues, suggest seeing a doctor."
full_input = f"{instruction}\nQuestion: {prompt} Answer:"
else:
# Causal LM format
full_input = f"{MEDICAL_SYSTEM_PROMPT}\n\nPatient/User: {prompt}\n"
print(f"Full input: {full_input}")
# Tokenize input with proper truncation (reduced max_length for T5)
inputs = tokenizer(
full_input,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
# Move to appropriate device
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
# Generation parameters - optimized for T5
generation_kwargs = {
"max_new_tokens": min(max_tokens, 256), # Reduced to 256 for control
"temperature": temperature,
"top_p": top_p,
"do_sample": GENERATION_DEFAULTS["do_sample"],
"repetition_penalty": GENERATION_DEFAULTS["repetition_penalty"],
"no_repeat_ngram_size": GENERATION_DEFAULTS["no_repeat_ngram_size"]
}
# Add pad_token_id for non-T5 models
if not ("flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower()):
generation_kwargs["pad_token_id"] = tokenizer.eos_token_id
print(f"Generating with kwargs: {generation_kwargs}")
# Generate response
print(f"πŸ€– Generating response with {current_model_name}...")
import time
start_time = time.time()
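        # Inference only: running under no_grad avoids building autograd state and saves memory.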
with torch.no_grad():
outputs = model.generate(**inputs, **generation_kwargs)
generation_time = time.time() - start_time
print(f"⏱️ Generation completed in {generation_time:.2f} seconds")
# Decode response - different handling for T5 vs causal models
if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
# T5 generates only the answer, no need to remove prompt
response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
else:
# Causal models generate prompt + answer, need to remove prompt
full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = full_response.replace(full_input, "").strip()
print(f"Generated response: {response}")
# Clean up response
if not response or len(response.strip()) < 10:
response = "Sorry, I couldn't process that. Try again or see a doctor. 😊"
print(f"βœ… Generated response length: {len(response)} characters")
print(f"πŸ“„ Response preview: {response[:150]}{'...' if len(response) > 150 else ''}")
# Clean up memory
del inputs, outputs
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect() # Force garbage collection
print(f"πŸ“œ Generated response: {response}")
return response
except Exception as e:
error_msg = f"Error generating response: {str(e)}"
print(error_msg)
return f"⚠️ I encountered a technical issue while processing your request. Please try again or rephrase your question. If the problem persists, consider consulting a healthcare professional directly."
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
"""Main response function for Gradio ChatInterface"""
if not message or not message.strip():
return "Please enter a medical question or concern."
# Add a disclaimer for first-time users
disclaimer = "\n\n⚠️ **Medical Disclaimer**: This AI provides general health information only. Always consult healthcare professionals for medical advice, diagnosis, or treatment."
try:
# Generate response
response = generate_response(
message.strip(),
max_tokens=int(max_tokens),
temperature=float(temperature),
top_p=float(top_p)
)
# Add disclaimer to response
if "disclaimer" not in response.lower() and "consult" not in response.lower():
response += disclaimer
return response
except Exception as e:
error_msg = f"System error: {str(e)}"
print(error_msg)
return f"⚠️ System temporarily unavailable. Please try again later or consult a healthcare professional directly.{disclaimer}"
def get_model_info():
"""Get information about the currently loaded model"""
if current_model_name:
return f"Currently using: {current_model_name}"
return "No model loaded"
# Load model on startup
print("πŸ₯ Initializing MedLLaMA2 Medical Chatbot...")
print("πŸ“‹ Loading medical language model...")
model_loaded = load_model()
if model_loaded:
print(f"βœ… Ready! {get_model_info()}")
else:
print("⚠️ WARNING: Model failed to load. The app will run but responses may be limited.")
# Create Gradio interface with configuration
demo = gr.ChatInterface(
respond,
title=UI_CONFIG["title"],
description=UI_CONFIG["description"],
additional_inputs=[
gr.Textbox(
value=MEDICAL_SYSTEM_PROMPT,
label="System Instructions",
lines=4,
interactive=False # Make it read-only to prevent tampering
),
gr.Slider(
minimum=UI_CONFIG["max_tokens_range"][0],
maximum=UI_CONFIG["max_tokens_range"][1],
value=GENERATION_DEFAULTS["max_new_tokens"],
step=10,
label="Max new tokens"
),
gr.Slider(
minimum=UI_CONFIG["temperature_range"][0],
maximum=UI_CONFIG["temperature_range"][1],
value=GENERATION_DEFAULTS["temperature"],
step=0.1,
label="Temperature (creativity)"
),
gr.Slider(
minimum=UI_CONFIG["top_p_range"][0],
maximum=UI_CONFIG["top_p_range"][1],
value=GENERATION_DEFAULTS["top_p"],
step=0.05,
label="Top-p (focus)",
),
],
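    # additional_inputs are passed to respond() positionally, after (message, history), in the order listed above.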
examples=[[example] for example in UI_CONFIG["examples"]],
cache_examples=False,
theme=gr.themes.Soft(),
css=".gradio-container {max-width: 900px; margin: auto;}"
)
# Add model info to the interface
with demo:
gr.HTML(f"<p style='text-align: center; color: #666; font-size: 0.9em;'>Model Status: {get_model_info()}</p>")
if __name__ == "__main__":
# For Hugging Face Spaces deployment
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_error=True,
debug=True
)