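"""Gradio medical chatbot app for Hugging Face Spaces.

Loads one of the medical language models defined in config.MODEL_CONFIGS
(falling back through a list of candidates, with optional 4-bit quantization
on GPU) and serves a gr.ChatInterface that appends a medical disclaimer to
generated answers.
"""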
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
import logging
import gc
import warnings
import os
from huggingface_hub import login
from config import MODEL_CONFIGS, DEFAULT_MODEL, MODEL_SETTINGS, GENERATION_DEFAULTS, MEDICAL_SYSTEM_PROMPT, UI_CONFIG

# Login with the Space secret token (skip if HF_TOKEN is not configured)
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)

# Suppress warnings
warnings.filterwarnings("ignore")
logging.getLogger("transformers").setLevel(logging.ERROR)

# Global variables for model and tokenizer
model = None
tokenizer = None
current_model_name = None

def load_model(model_key=None):
    """Load the specified medical model with optimizations for Hugging Face Spaces"""
    global model, tokenizer, current_model_name
    
    if model_key is None:
        model_key = DEFAULT_MODEL
    
    # Try to load models in order of preference - prioritize lightweight models
    model_keys_to_try = [model_key, "flan_t5_small", "dialogpt_medium", "meditron"]
    
    for key in model_keys_to_try:
        if key not in MODEL_CONFIGS:
            continue
            
        try:
            model_config = MODEL_CONFIGS[key]
            model_name = model_config["name"]
            print(f"Attempting to load model: {model_name} ({model_config['description']})")
            
            # Load tokenizer first
            print("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=MODEL_SETTINGS["trust_remote_code"],
                padding_side="left"
            )
            
            # Add a pad token if the tokenizer doesn't define one
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            
            # Configure quantization for memory efficiency (only for larger models)
            model_kwargs = {
                "trust_remote_code": MODEL_SETTINGS["trust_remote_code"],
                "low_cpu_mem_usage": MODEL_SETTINGS["low_cpu_mem_usage"]
            }
            
            # Optimized loading for CPU performance
            if MODEL_SETTINGS["use_quantization"] and torch.cuda.is_available() and key in ["medllama2", "meditron", "clinical_camel"]:
                # Only use quantization on GPU for larger models
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                )
                model_kwargs["quantization_config"] = quantization_config
                model_kwargs["torch_dtype"] = torch.float16
                model_kwargs["device_map"] = MODEL_SETTINGS["device_map"]
            else:
                # For CPU or smaller models, use optimized settings
                if torch.cuda.is_available():
                    model_kwargs["torch_dtype"] = torch.float16
                    model_kwargs["device_map"] = "auto"
                else:
                    # CPU-optimized settings
                    model_kwargs["torch_dtype"] = torch.float32  # Use float32 on CPU
                    model_kwargs["device_map"] = None  # Let it use CPU naturally
            
            print("Loading model...")
            # Use appropriate model class based on model type
            if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
                model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
            else:
                model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
            
            current_model_name = model_name
            print(f"βœ… Model loaded successfully: {model_name}")
            return True
            
        except Exception as e:
            print(f"❌ Failed to load {key}: {str(e)}")
            # Clean up on failure
            model = None
            tokenizer = None
            continue
    
    print("❌ All model loading attempts failed")
    return False

def generate_response(prompt, max_tokens=None, temperature=None, top_p=None):
    """Generate response using the loaded model"""
    global model, tokenizer, current_model_name
    
    print(f"Starting generation for prompt: {prompt}")
    if not prompt or not prompt.strip():
        return "Please enter a question. 😊"
    
    if model is None or tokenizer is None:
        return "❌ Model not loaded. Please wait for initialization or try restarting the space."
    
    # Use defaults if not specified (explicit None checks so a 0 value isn't silently overridden)
    max_tokens = max_tokens if max_tokens is not None else GENERATION_DEFAULTS["max_new_tokens"]
    temperature = temperature if temperature is not None else GENERATION_DEFAULTS["temperature"]
    top_p = top_p if top_p is not None else GENERATION_DEFAULTS["top_p"]
    
    try:
        # Format prompt based on model type
        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
            # Use a concise instruction prefix for T5
            instruction = "You are a friendly medical assistant. Answer with short, clear health info. Use emojis like 😊. For serious issues, suggest seeing a doctor."
            full_input = f"{instruction}\nQuestion: {prompt} Answer:"
        else:
            # Causal LM format
            full_input = f"{MEDICAL_SYSTEM_PROMPT}\n\nPatient/User: {prompt}\n"
        
        print(f"Full input: {full_input}")
        
        # Tokenize input with proper truncation (reduced max_length for T5)
        inputs = tokenizer(
            full_input, 
            return_tensors="pt", 
            truncation=True, 
            max_length=512,
            padding=True
        )
        
        # Move to appropriate device
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        # Generation parameters - optimized for T5
        generation_kwargs = {
            "max_new_tokens": min(max_tokens, 256),  # Reduced to 256 for control
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": GENERATION_DEFAULTS["do_sample"],
            "repetition_penalty": GENERATION_DEFAULTS["repetition_penalty"],
            "no_repeat_ngram_size": GENERATION_DEFAULTS["no_repeat_ngram_size"]
        }
        
        # Add pad_token_id for non-T5 models
        if not ("flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower()):
            generation_kwargs["pad_token_id"] = tokenizer.eos_token_id
        
        print(f"Generating with kwargs: {generation_kwargs}")
        
        # Generate response
        print(f"πŸ€– Generating response with {current_model_name}...")
        import time
        start_time = time.time()
        
        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_kwargs)
        
        generation_time = time.time() - start_time
        print(f"⏱️ Generation completed in {generation_time:.2f} seconds")
        
        # Decode response - different handling for T5 vs causal models
        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
            # Seq2seq models return only the generated answer
            response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        else:
            # Causal models return prompt + answer; decode only the newly generated tokens
            input_length = inputs["input_ids"].shape[1]
            response = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
        
        print(f"Generated response: {response}")
        
        # Clean up response
        if not response or len(response.strip()) < 10:
            response = "Sorry, I couldn't process that. Try again or see a doctor. 😊"
        
        print(f"βœ… Generated response length: {len(response)} characters")
        print(f"πŸ“„ Response preview: {response[:150]}{'...' if len(response) > 150 else ''}")
        
        # Clean up memory
        del inputs, outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()  # Force garbage collection
        
        print(f"πŸ“œ Generated response: {response}")
        return response
        
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        print(error_msg)
        return f"⚠️ I encountered a technical issue while processing your request. Please try again or rephrase your question. If the problem persists, consider consulting a healthcare professional directly."

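# Gradio's ChatInterface calls this with (message, history) followed by the
# additional_inputs in the order they are declared below. The system_message
# textbox is shown read-only; the system prompt itself is applied inside
# generate_response() via MEDICAL_SYSTEM_PROMPT.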
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Main response function for Gradio ChatInterface"""
    if not message or not message.strip():
        return "Please enter a medical question or concern."
    
    # Add a disclaimer for first-time users
    disclaimer = "\n\n⚠️ **Medical Disclaimer**: This AI provides general health information only. Always consult healthcare professionals for medical advice, diagnosis, or treatment."
    
    try:
        # Generate response
        response = generate_response(
            message.strip(),
            max_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p)
        )
        
        # Add disclaimer to response
        if "disclaimer" not in response.lower() and "consult" not in response.lower():
            response += disclaimer
        
        return response
        
    except Exception as e:
        error_msg = f"System error: {str(e)}"
        print(error_msg)
        return f"⚠️ System temporarily unavailable. Please try again later or consult a healthcare professional directly.{disclaimer}"

def get_model_info():
    """Get information about the currently loaded model"""
    if current_model_name:
        return f"Currently using: {current_model_name}"
    return "No model loaded"

# Load model on startup
print("πŸ₯ Initializing MedLLaMA2 Medical Chatbot...")
print("πŸ“‹ Loading medical language model...")
model_loaded = load_model()

if model_loaded:
    print(f"βœ… Ready! {get_model_info()}")
else:
    print("⚠️ WARNING: Model failed to load. The app will run but responses may be limited.")

# Create Gradio interface with configuration
demo = gr.ChatInterface(
    respond,
    title=UI_CONFIG["title"],
    description=UI_CONFIG["description"],
    additional_inputs=[
        gr.Textbox(
            value=MEDICAL_SYSTEM_PROMPT,
            label="System Instructions",
            lines=4,
            interactive=False  # Make it read-only to prevent tampering
        ),
        gr.Slider(
            minimum=UI_CONFIG["max_tokens_range"][0], 
            maximum=UI_CONFIG["max_tokens_range"][1], 
            value=GENERATION_DEFAULTS["max_new_tokens"], 
            step=10, 
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=UI_CONFIG["temperature_range"][0], 
            maximum=UI_CONFIG["temperature_range"][1], 
            value=GENERATION_DEFAULTS["temperature"], 
            step=0.1, 
            label="Temperature (creativity)"
        ),
        gr.Slider(
            minimum=UI_CONFIG["top_p_range"][0],
            maximum=UI_CONFIG["top_p_range"][1],
            value=GENERATION_DEFAULTS["top_p"],
            step=0.05,
            label="Top-p (focus)",
        ),
    ],
    examples=[[example] for example in UI_CONFIG["examples"]],
    cache_examples=False,
    theme=gr.themes.Soft(),
    css=".gradio-container {max-width: 900px; margin: auto;}"
)

# Add model info to the interface
with demo:
    gr.HTML(f"<p style='text-align: center; color: #666; font-size: 0.9em;'>Model Status: {get_model_info()}</p>")

if __name__ == "__main__":
    # For Hugging Face Spaces deployment
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,  # note: has no effect when running on Hugging Face Spaces
        show_error=True,
        debug=True
    )