yashbyname commited on
Commit
f3a1e2d
·
verified ·
1 Parent(s): f60e215

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -171
app.py CHANGED
@@ -1,168 +1,11 @@
1
- # import requests
2
- # import numpy as np
3
- # import tensorflow as tf
4
- # import tensorflow_hub as hub
5
- # import gradio as gr
6
- # from PIL import Image
7
-
8
- # # Load models
9
- # #model_initial = keras.models.load_model(
10
- # # "models/initial_model.h5", custom_objects={'KerasLayer': hub.KerasLayer}
11
- # #)
12
- # #model_tumor = keras.models.load_model(
13
- # # "models/model_tumor.h5", custom_objects={'KerasLayer': hub.KerasLayer}
14
- # #)
15
- # #model_stroke = keras.models.load_model(
16
- # # "models/model_stroke.h5", custom_objects={'KerasLayer': hub.KerasLayer}
17
- # #)
18
- # #model_alzheimer = keras.models.load_model(
19
- # # "models/model_alzheimer.h5", custom_objects={'KerasLayer': hub.KerasLayer}
20
-
21
- # # API key and user ID for on-demand
22
- # api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
23
- # external_user_id = 'plugin-1717464304'
24
-
25
- # # Step 1: Create a chat session with the API
26
- # def create_chat_session():
27
- # create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
28
- # create_session_headers = {
29
- # 'apikey': api_key
30
- # }
31
- # create_session_body = {
32
- # "pluginIds": [],
33
- # "externalUserId": external_user_id
34
- # }
35
- # response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
36
- # response_data = response.json()
37
- # session_id = response_data['data']['id']
38
- # return session_id
39
-
40
- # # Step 2: Submit query to the API
41
- # def submit_query(session_id, query):
42
- # submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
43
- # submit_query_headers = {
44
- # 'apikey': api_key
45
- # }
46
- # submit_query_body = {
47
- # "endpointId": "predefined-openai-gpt4o",
48
- # "query": query,
49
- # "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
50
- # "responseMode": "sync"
51
- # }
52
- # response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
53
- # return response.json()
54
-
55
- # # Combined disease model (placeholder)
56
- # class CombinedDiseaseModel(tf.keras.Model):
57
- # def __init__(self, model_initial, model_alzheimer, model_tumor, model_stroke):
58
- # super(CombinedDiseaseModel, self).__init__()
59
- # self.model_initial = model_initial
60
- # self.model_alzheimer = model_alzheimer
61
- # self.model_tumor = model_tumor
62
- # self.model_stroke = model_stroke
63
- # self.disease_labels = ["Alzheimer's", 'No Disease', 'Stroke', 'Tumor']
64
- # self.sub_models = {
65
- # "Alzheimer's": model_alzheimer,
66
- # 'Tumor': model_tumor,
67
- # 'Stroke': model_stroke
68
- # }
69
-
70
- # def call(self, inputs):
71
- # initial_probs = self.model_initial(inputs, training=False)
72
- # main_disease_idx = tf.argmax(initial_probs, axis=1)
73
- # main_disease = self.disease_labels[main_disease_idx[0].numpy()]
74
- # main_disease_prob = initial_probs[0, main_disease_idx[0]].numpy()
75
-
76
- # if main_disease == 'No Disease':
77
- # sub_category = "No Disease"
78
- # sub_category_prob = main_disease_prob
79
- # else:
80
- # sub_model = self.sub_models[main_disease]
81
- # sub_category_pred = sub_model(inputs, training=False)
82
- # sub_category = tf.argmax(sub_category_pred, axis=1).numpy()[0]
83
- # sub_category_prob = sub_category_pred[0, sub_category].numpy()
84
-
85
- # if main_disease == "Alzheimer's":
86
- # sub_category_label = ['Very Mild', 'Mild', 'Moderate']
87
- # elif main_disease == 'Tumor':
88
- # sub_category_label = ['Glioma', 'Meningioma', 'Pituitary']
89
- # elif main_disease == 'Stroke':
90
- # sub_category_label = ['Ischemic', 'Hemorrhagic']
91
-
92
- # sub_category = sub_category_label[sub_category]
93
-
94
- # return f"The MRI image shows {main_disease} with a probability of {main_disease_prob*100:.2f}%.\n" \
95
- # f"The subcategory of {main_disease} is {sub_category} with a probability of {sub_category_prob*100:.2f}%."
96
-
97
- # # Placeholder function to process images
98
- # def process_image(image):
99
- # image = image.resize((256, 256))
100
- # image.convert("RGB")
101
- # image_array = np.array(image) / 255.0
102
- # image_array = np.expand_dims(image_array, axis=0)
103
- # # Prediction logic here
104
- # # predictions = cnn_model(image_array)
105
- # return "Mock prediction: Disease identified with a probability of 85%."
106
-
107
- # # Function to handle patient info, query, and image processing
108
- # def gradio_interface(patient_info, query_type, image):
109
- # if image is not None:
110
- # image_response = process_image(image)
111
-
112
- # # Call LLM with patient info and query
113
- # session_id = create_chat_session()
114
- # query = f"Patient Info: {patient_info}\nQuery Type: {query_type}"
115
- # llm_response = submit_query(session_id, query)
116
-
117
- # # Debug: Print the full response to inspect it
118
- # print("LLM Response:", llm_response) # This will print the full response for inspection
119
-
120
- # # Safely handle 'message' if it exists
121
- # message = llm_response.get('data', {}).get('message', 'No message returned from LLM')
122
-
123
- # # Check if message is empty and print the complete response if necessary
124
- # if message == 'No message returned from LLM':
125
- # print("Full LLM Response Data:", llm_response) # Inspect the full LLM response for any helpful info
126
-
127
- # response = f"Patient Info: {patient_info}\nQuery Type: {query_type}\n\n{image_response}\n\nLLM Response:\n{message}"
128
- # return response
129
- # else:
130
- # return "Please upload an image."
131
-
132
- # # Gradio interface
133
- # iface = gr.Interface(
134
- # fn=gradio_interface,
135
- # inputs=[
136
- # gr.Textbox(
137
- # label="Patient Information",
138
- # placeholder="Enter patient details here...",
139
- # lines=5,
140
- # max_lines=10
141
- # ),
142
- # gr.Textbox(
143
- # label="Query Type",
144
- # placeholder="Describe the type of diagnosis or information needed..."
145
- # ),
146
- # gr.Image(
147
- # type="pil",
148
- # label="Upload an MRI Image",
149
- # )
150
- # ],
151
- # outputs=gr.Textbox(label="Response", placeholder="The response will appear here..."),
152
- # title="Medical Diagnosis with MRI and LLM",
153
- # description="Upload MRI images and provide patient information for a combined CNN model and LLM analysis."
154
- # )
155
-
156
- # iface.launch()
157
-
158
-
159
-
160
-
161
-
162
  import requests
163
  import gradio as gr
164
  import logging
165
  import json
 
 
 
 
166
 
167
  # Set up logging
168
  logging.basicConfig(level=logging.INFO)
@@ -172,6 +15,39 @@ logger = logging.getLogger(__name__)
172
  api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
173
  external_user_id = 'plugin-1717464304'
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  def create_chat_session():
176
  try:
177
  create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
@@ -192,7 +68,7 @@ def create_chat_session():
192
  logger.error(f"Error creating chat session: {str(e)}")
193
  raise
194
 
195
- def submit_query(session_id, query):
196
  try:
197
  submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
198
  submit_query_headers = {
@@ -200,10 +76,15 @@ def submit_query(session_id, query):
200
  'Content-Type': 'application/json'
201
  }
202
 
 
 
 
 
 
203
  structured_query = f"""
204
- Based on the following patient information, provide a detailed medical analysis in JSON format:
205
 
206
- {query}
207
 
208
  Return only valid JSON with these fields:
209
  - diagnosis_details
@@ -214,6 +95,7 @@ def submit_query(session_id, query):
214
  - additional_tests (array)
215
  - precautions (array)
216
  - follow_up (string)
 
217
  """
218
 
219
  submit_query_body = {
@@ -234,11 +116,9 @@ def submit_query(session_id, query):
234
  def extract_json_from_answer(answer):
235
  """Extract and clean JSON from the LLM response"""
236
  try:
237
- # First try to parse the answer directly
238
  return json.loads(answer)
239
  except json.JSONDecodeError:
240
  try:
241
- # If that fails, try to find JSON content and parse it
242
  start_idx = answer.find('{')
243
  end_idx = answer.rfind('}') + 1
244
  if start_idx != -1 and end_idx != 0:
@@ -248,10 +128,31 @@ def extract_json_from_answer(answer):
248
  logger.error("Failed to parse JSON from response")
249
  raise
250
 
251
- def gradio_interface(patient_info):
 
 
 
252
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  session_id = create_chat_session()
254
- llm_response = submit_query(session_id, patient_info)
 
255
 
256
  if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
257
  raise ValueError("Invalid response structure")
@@ -259,7 +160,7 @@ def gradio_interface(patient_info):
259
  # Extract and clean JSON from the response
260
  json_data = extract_json_from_answer(llm_response['data']['answer'])
261
 
262
- # Return clean JSON string without extra formatting
263
  return json.dumps(json_data)
264
 
265
  except Exception as e:
@@ -275,6 +176,11 @@ iface = gr.Interface(
275
  placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
276
  lines=5,
277
  max_lines=10
 
 
 
 
 
278
  )
279
  ],
280
  outputs=gr.Textbox(
@@ -283,8 +189,142 @@ iface = gr.Interface(
283
  lines=15
284
  ),
285
  title="Medical Diagnosis Assistant",
286
- description="Enter detailed patient information to receive a structured medical analysis in JSON format."
287
  )
288
 
289
  if __name__ == "__main__":
290
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  import gradio as gr
3
  import logging
4
  import json
5
+ import tensorflow as tf
6
+ import numpy as np
7
+ from PIL import Image
8
+ import io
9
 
10
  # Set up logging
11
  logging.basicConfig(level=logging.INFO)
 
15
  api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
16
  external_user_id = 'plugin-1717464304'
17
 
18
+ # Load the keras model
19
+ def load_model():
20
+ try:
21
+ model = tf.keras.models.load_model('model_epoch_01.h5.keras')
22
+ logger.info("Model loaded successfully")
23
+ return model
24
+ except Exception as e:
25
+ logger.error(f"Error loading model: {str(e)}")
26
+ raise
27
+
28
+ # Preprocess image for model
29
+ def preprocess_image(image):
30
+ try:
31
+ # Convert to numpy array if needed
32
+ if isinstance(image, Image.Image):
33
+ image = np.array(image)
34
+
35
+ # Resize image to match model's expected input shape
36
+ # Note: Adjust these dimensions to match your model's requirements
37
+ target_size = (224, 224) # Change this to match your model's input size
38
+ image = tf.image.resize(image, target_size)
39
+
40
+ # Normalize pixel values
41
+ image = image / 255.0
42
+
43
+ # Add batch dimension
44
+ image = np.expand_dims(image, axis=0)
45
+
46
+ return image
47
+ except Exception as e:
48
+ logger.error(f"Error preprocessing image: {str(e)}")
49
+ raise
50
+
51
  def create_chat_session():
52
  try:
53
  create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
 
68
  logger.error(f"Error creating chat session: {str(e)}")
69
  raise
70
 
71
+ def submit_query(session_id, query, image_analysis=None):
72
  try:
73
  submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
74
  submit_query_headers = {
 
76
  'Content-Type': 'application/json'
77
  }
78
 
79
+ # Include image analysis in the query if available
80
+ query_with_image = query
81
+ if image_analysis:
82
+ query_with_image += f"\n\nImage Analysis Results: {image_analysis}"
83
+
84
  structured_query = f"""
85
+ Based on the following patient information and image analysis, provide a detailed medical analysis in JSON format:
86
 
87
+ {query_with_image}
88
 
89
  Return only valid JSON with these fields:
90
  - diagnosis_details
 
95
  - additional_tests (array)
96
  - precautions (array)
97
  - follow_up (string)
98
+ - image_findings (object with prediction and confidence)
99
  """
100
 
101
  submit_query_body = {
 
116
  def extract_json_from_answer(answer):
117
  """Extract and clean JSON from the LLM response"""
118
  try:
 
119
  return json.loads(answer)
120
  except json.JSONDecodeError:
121
  try:
 
122
  start_idx = answer.find('{')
123
  end_idx = answer.rfind('}') + 1
124
  if start_idx != -1 and end_idx != 0:
 
128
  logger.error("Failed to parse JSON from response")
129
  raise
130
 
131
+ # Initialize the model
132
+ model = load_model()
133
+
134
+ def gradio_interface(patient_info, image):
135
  try:
136
+ # Process image if provided
137
+ image_analysis = None
138
+ if image is not None:
139
+ # Preprocess image
140
+ processed_image = preprocess_image(image)
141
+
142
+ # Get model prediction
143
+ prediction = model.predict(processed_image)
144
+
145
+ # Format prediction results
146
+ # Note: Adjust this based on your model's output format
147
+ image_analysis = {
148
+ "prediction": float(prediction[0][0]), # Adjust indexing based on your model's output
149
+ "confidence": float(prediction[0][0]) * 100 # Convert to percentage
150
+ }
151
+
152
+ # Create chat session and submit query
153
  session_id = create_chat_session()
154
+ llm_response = submit_query(session_id, patient_info,
155
+ json.dumps(image_analysis) if image_analysis else None)
156
 
157
  if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
158
  raise ValueError("Invalid response structure")
 
160
  # Extract and clean JSON from the response
161
  json_data = extract_json_from_answer(llm_response['data']['answer'])
162
 
163
+ # Return clean JSON string
164
  return json.dumps(json_data)
165
 
166
  except Exception as e:
 
176
  placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
177
  lines=5,
178
  max_lines=10
179
+ ),
180
+ gr.Image(
181
+ label="Medical Image",
182
+ type="numpy",
183
+ optional=True
184
  )
185
  ],
186
  outputs=gr.Textbox(
 
189
  lines=15
190
  ),
191
  title="Medical Diagnosis Assistant",
192
+ description="Enter patient information and optionally upload a medical image for analysis."
193
  )
194
 
195
  if __name__ == "__main__":
196
+ iface.launch()
197
+
198
+
199
+
200
+
201
+
202
+ # import requests
203
+ # import gradio as gr
204
+ # import logging
205
+ # import json
206
+
207
+ # # Set up logging
208
+ # logging.basicConfig(level=logging.INFO)
209
+ # logger = logging.getLogger(__name__)
210
+
211
+ # # API key and user ID for on-demand
212
+ # api_key = 'KGSjxB1uptfSk8I8A7ciCuNT9Xa3qWC3'
213
+ # external_user_id = 'plugin-1717464304'
214
+
215
+ # def create_chat_session():
216
+ # try:
217
+ # create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
218
+ # create_session_headers = {
219
+ # 'apikey': api_key,
220
+ # 'Content-Type': 'application/json'
221
+ # }
222
+ # create_session_body = {
223
+ # "pluginIds": [],
224
+ # "externalUserId": external_user_id
225
+ # }
226
+
227
+ # response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
228
+ # response.raise_for_status()
229
+ # return response.json()['data']['id']
230
+
231
+ # except requests.exceptions.RequestException as e:
232
+ # logger.error(f"Error creating chat session: {str(e)}")
233
+ # raise
234
+
235
+ # def submit_query(session_id, query):
236
+ # try:
237
+ # submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
238
+ # submit_query_headers = {
239
+ # 'apikey': api_key,
240
+ # 'Content-Type': 'application/json'
241
+ # }
242
+
243
+ # structured_query = f"""
244
+ # Based on the following patient information, provide a detailed medical analysis in JSON format:
245
+
246
+ # {query}
247
+
248
+ # Return only valid JSON with these fields:
249
+ # - diagnosis_details
250
+ # - probable_diagnoses (array)
251
+ # - treatment_plans (array)
252
+ # - lifestyle_modifications (array)
253
+ # - medications (array of objects with name and dosage)
254
+ # - additional_tests (array)
255
+ # - precautions (array)
256
+ # - follow_up (string)
257
+ # """
258
+
259
+ # submit_query_body = {
260
+ # "endpointId": "predefined-openai-gpt4o",
261
+ # "query": structured_query,
262
+ # "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
263
+ # "responseMode": "sync"
264
+ # }
265
+
266
+ # response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
267
+ # response.raise_for_status()
268
+ # return response.json()
269
+
270
+ # except requests.exceptions.RequestException as e:
271
+ # logger.error(f"Error submitting query: {str(e)}")
272
+ # raise
273
+
274
+ # def extract_json_from_answer(answer):
275
+ # """Extract and clean JSON from the LLM response"""
276
+ # try:
277
+ # # First try to parse the answer directly
278
+ # return json.loads(answer)
279
+ # except json.JSONDecodeError:
280
+ # try:
281
+ # # If that fails, try to find JSON content and parse it
282
+ # start_idx = answer.find('{')
283
+ # end_idx = answer.rfind('}') + 1
284
+ # if start_idx != -1 and end_idx != 0:
285
+ # json_str = answer[start_idx:end_idx]
286
+ # return json.loads(json_str)
287
+ # except (json.JSONDecodeError, ValueError):
288
+ # logger.error("Failed to parse JSON from response")
289
+ # raise
290
+
291
+ # def gradio_interface(patient_info):
292
+ # try:
293
+ # session_id = create_chat_session()
294
+ # llm_response = submit_query(session_id, patient_info)
295
+
296
+ # if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
297
+ # raise ValueError("Invalid response structure")
298
+
299
+ # # Extract and clean JSON from the response
300
+ # json_data = extract_json_from_answer(llm_response['data']['answer'])
301
+
302
+ # # Return clean JSON string without extra formatting
303
+ # return json.dumps(json_data)
304
+
305
+ # except Exception as e:
306
+ # logger.error(f"Error in gradio_interface: {str(e)}")
307
+ # return json.dumps({"error": str(e)})
308
+
309
+ # # Gradio interface
310
+ # iface = gr.Interface(
311
+ # fn=gradio_interface,
312
+ # inputs=[
313
+ # gr.Textbox(
314
+ # label="Patient Information",
315
+ # placeholder="Enter patient details including: symptoms, medical history, current medications, age, gender, and any relevant test results...",
316
+ # lines=5,
317
+ # max_lines=10
318
+ # )
319
+ # ],
320
+ # outputs=gr.Textbox(
321
+ # label="Medical Analysis",
322
+ # placeholder="JSON analysis will appear here...",
323
+ # lines=15
324
+ # ),
325
+ # title="Medical Diagnosis Assistant",
326
+ # description="Enter detailed patient information to receive a structured medical analysis in JSON format."
327
+ # )
328
+
329
+ # if __name__ == "__main__":
330
+ # iface.launch()