Deva1211 committed
Commit 5bb3d19 · 1 Parent(s): e303824

Fixed memory issue

Files changed (2)
  1. app.py +48 -19
  2. config.py +2 -2
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
 import logging
 import gc
 import warnings
@@ -27,8 +27,8 @@ def load_model(model_key=None):
     if model_key is None:
         model_key = DEFAULT_MODEL
 
-    # Try to load models in order of preference
-    model_keys_to_try = [model_key, "meditron", "dialogpt_medium", "dialogpt_small"]
+    # Try to load models in order of preference - prioritize lightweight models
+    model_keys_to_try = [model_key, "flan_t5_small", "dialogpt_medium", "meditron"]
 
     for key in model_keys_to_try:
         if key not in MODEL_CONFIGS:
@@ -80,7 +80,11 @@
             model_kwargs["device_map"] = None  # Let it use CPU naturally
 
         print("Loading model...")
-        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
+        # Use appropriate model class based on model type
+        if "flan-t5" in model_name.lower() or "t5" in model_name.lower():
+            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **model_kwargs)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
 
         current_model_name = model_name
         print(f"✅ Model loaded successfully: {model_name}")
@@ -113,7 +117,14 @@ def generate_response(prompt, max_tokens=None, temperature=None, top_p=None):
     top_p = top_p or GENERATION_DEFAULTS["top_p"]
 
     try:
-        full_prompt = f"{MEDICAL_SYSTEM_PROMPT}\n\nPatient/User: {prompt}\n"
+        # Format prompt based on model type
+        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
+            # T5 instruction format
+            full_prompt = f"{MEDICAL_SYSTEM_PROMPT}\n\nQuestion: {prompt}\nAnswer:"
+        else:
+            # Causal LM format
+            full_prompt = f"{MEDICAL_SYSTEM_PROMPT}\n\nPatient/User: {prompt}\n"
+
         print(f"Full prompt: {full_prompt}")
 
         # Tokenize input with proper truncation
@@ -121,7 +132,7 @@
             full_prompt,
             return_tensors="pt",
             truncation=True,
-            max_length=512,  # Reduced for DialoGPT
+            max_length=512,
             padding=True
         )
 
@@ -129,16 +140,28 @@
         device = next(model.parameters()).device
         inputs = {k: v.to(device) for k, v in inputs.items()}
 
-        # Generation parameters
-        generation_kwargs = {
-            "max_new_tokens": min(max_tokens, 1024),
-            "temperature": temperature,
-            "top_p": top_p,
-            "do_sample": GENERATION_DEFAULTS["do_sample"],
-            "pad_token_id": tokenizer.eos_token_id,
-            "repetition_penalty": GENERATION_DEFAULTS["repetition_penalty"],
-            "no_repeat_ngram_size": GENERATION_DEFAULTS["no_repeat_ngram_size"]
-        }
+        # Generation parameters - different for T5 vs causal models
+        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
+            # T5 seq2seq generation parameters
+            generation_kwargs = {
+                "max_new_tokens": min(max_tokens, 100),
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": GENERATION_DEFAULTS["do_sample"],
+                "repetition_penalty": GENERATION_DEFAULTS["repetition_penalty"],
+                "early_stopping": True
+            }
+        else:
+            # Causal LM generation parameters
+            generation_kwargs = {
+                "max_new_tokens": min(max_tokens, 1024),
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": GENERATION_DEFAULTS["do_sample"],
+                "pad_token_id": tokenizer.eos_token_id,
+                "repetition_penalty": GENERATION_DEFAULTS["repetition_penalty"],
+                "no_repeat_ngram_size": GENERATION_DEFAULTS["no_repeat_ngram_size"]
+            }
 
         print(f"Generating with kwargs: {generation_kwargs}")
 
@@ -153,9 +176,15 @@
         generation_time = time.time() - start_time
         print(f"⏱️ Generation completed in {generation_time:.2f} seconds")
 
-        # Decode response and extract new content
-        full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        response = full_response.replace(full_prompt, "").strip()
+        # Decode response - different handling for T5 vs causal models
+        if "flan-t5" in current_model_name.lower() or "t5" in current_model_name.lower():
+            # T5 generates only the answer, no need to remove prompt
+            response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+        else:
+            # Causal models generate prompt + answer, need to remove prompt
+            full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = full_response.replace(full_prompt, "").strip()
+
         print(f"Generated response: {response}")
 
         # Clean up response
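
For context, a minimal standalone sketch of the seq2seq branch this commit introduces. The checkpoint name "google/flan-t5-small" and the MEDICAL_SYSTEM_PROMPT text are assumptions for illustration (neither appears in this diff); the point is that an encoder-decoder model is loaded with AutoModelForSeq2SeqLM and its generate() output contains only the answer tokens, so the prompt does not need to be stripped afterwards.

# Sketch only: the model name and system prompt below are assumed, not taken from the repo.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "google/flan-t5-small"  # assumed lightweight checkpoint (~80M parameters)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)  # encoder-decoder class, not CausalLM

MEDICAL_SYSTEM_PROMPT = "You are a careful medical information assistant."  # placeholder
prompt = "What are common causes of headaches?"
full_prompt = f"{MEDICAL_SYSTEM_PROMPT}\n\nQuestion: {prompt}\nAnswer:"

inputs = tokenizer(full_prompt, return_tensors="pt", truncation=True, max_length=512)
outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.9)

# Seq2seq generate() returns decoder tokens only, so the decoded string is already
# just the answer; the full_prompt.replace(...) step is only needed for causal models.
response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
print(response)
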
config.py CHANGED
@@ -16,8 +16,8 @@ MODEL_CONFIGS = {
     }
 }
 
-# Default model to use - prioritize medical model
-DEFAULT_MODEL = "meditron"
+# Default model to use - lightweight for 16GB memory limit
+DEFAULT_MODEL = "flan_t5_small"
 
 # Model loading settings (optimized for CPU)
 MODEL_SETTINGS = {
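
The config change is what addresses the memory issue named in the commit title: the default moves from a 7B-class medical model to a small seq2seq model. A rough back-of-the-envelope check (a sketch assuming the "meditron" entry points at a 7B-parameter checkpoint and "flan_t5_small" at the ~80M-parameter google/flan-t5-small, neither of which is spelled out in this diff) shows that float32 weights alone already explain the 16 GB limit being hit:

# Approximate float32 weight memory, ignoring activations, KV cache, and framework overhead.
def fp32_weight_gib(n_params: float) -> float:
    return n_params * 4 / 1024**3  # 4 bytes per float32 parameter

print(f"7B causal model : ~{fp32_weight_gib(7e9):.1f} GiB")   # ~26 GiB, well over a 16 GB limit
print(f"flan-t5-small   : ~{fp32_weight_gib(8e7):.2f} GiB")   # ~0.30 GiB, comfortably under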