kkruel8100 commited on
Commit
0c76a0d
1 Parent(s): 6ddce78

commit gradio app

Browse files
Files changed (2) hide show
  1. app.py +406 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ # In[17]:
5
+
6
+
7
+ import pickle
8
+ from PIL import Image
9
+ import numpy as np
10
+ import gradio as gr
11
+ from pathlib import Path
12
+ from transformers import pipeline
13
+ from tensorflow.keras.models import load_model
14
+ import tensorflow as tf
15
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
16
+ from dotenv import load_dotenv
17
+ import openai
18
+ import os
19
+ from langchain.schema import HumanMessage, SystemMessage
20
+ from langchain_openai import ChatOpenAI
21
+
22
+
23
+ # In[18]:
24
+
25
+
26
+ # Set the model's file path
27
+ file_path = Path("models/model_adam_scaled.h5")
28
+
29
+ # Load the model to a new object
30
+ adam_5 = tf.keras.models.load_model(file_path)
31
+
32
+ # Load env variables
33
+ load_dotenv()
34
+
35
+ # Add your OpenAI API key here
36
+ openai_api_key = os.getenv("OPENAI_API_KEY")
37
+
38
+ print(f"OpenAI API Key Loaded: {openai_api_key is not None}")
39
+
40
+
41
+ # Load the model and tokenizer for translation
42
+ model = MBartForConditionalGeneration.from_pretrained(
43
+ "facebook/mbart-large-50-many-to-many-mmt"
44
+ )
45
+ tokenizer = MBart50TokenizerFast.from_pretrained(
46
+ "facebook/mbart-large-50-many-to-many-mmt"
47
+ )
48
+
49
+ # Set source language
50
+ tokenizer.src_lang = "en_XX"
51
+
52
+
53
+ # In[22]:
54
+
55
+
56
+ # Constants
57
+ # Language information MBart
58
+ language_info = [
59
+ "English (en_XX)",
60
+ "Arabic (ar_AR)",
61
+ "Czech (cs_CZ)",
62
+ "German (de_DE)",
63
+ "Spanish (es_XX)",
64
+ "Estonian (et_EE)",
65
+ "Finnish (fi_FI)",
66
+ "French (fr_XX)",
67
+ "Gujarati (gu_IN)",
68
+ "Hindi (hi_IN)",
69
+ "Italian (it_IT)",
70
+ "Japanese (ja_XX)",
71
+ "Kazakh (kk_KZ)",
72
+ "Korean (ko_KR)",
73
+ "Lithuanian (lt_LT)",
74
+ "Latvian (lv_LV)",
75
+ "Burmese (my_MM)",
76
+ "Nepali (ne_NP)",
77
+ "Dutch (nl_XX)",
78
+ "Romanian (ro_RO)",
79
+ "Russian (ru_RU)",
80
+ "Sinhala (si_LK)",
81
+ "Turkish (tr_TR)",
82
+ "Vietnamese (vi_VN)",
83
+ "Chinese (zh_CN)",
84
+ "Afrikaans (af_ZA)",
85
+ "Azerbaijani (az_AZ)",
86
+ "Bengali (bn_IN)",
87
+ "Persian (fa_IR)",
88
+ "Hebrew (he_IL)",
89
+ "Croatian (hr_HR)",
90
+ "Indonesian (id_ID)",
91
+ "Georgian (ka_GE)",
92
+ "Khmer (km_KH)",
93
+ "Macedonian (mk_MK)",
94
+ "Malayalam (ml_IN)",
95
+ "Mongolian (mn_MN)",
96
+ "Marathi (mr_IN)",
97
+ "Polish (pl_PL)",
98
+ "Pashto (ps_AF)",
99
+ "Portuguese (pt_XX)",
100
+ "Swedish (sv_SE)",
101
+ "Swahili (sw_KE)",
102
+ "Tamil (ta_IN)",
103
+ "Telugu (te_IN)",
104
+ "Thai (th_TH)",
105
+ "Tagalog (tl_XX)",
106
+ "Ukrainian (uk_UA)",
107
+ "Urdu (ur_PK)",
108
+ "Xhosa (xh_ZA)",
109
+ "Galician (gl_ES)",
110
+ "Slovene (sl_SI)",
111
+ ]
112
+
113
+ # Convert the information into a dictionary
114
+ language_dict = {}
115
+ for info in language_info:
116
+ name, code = info.split(" (")
117
+ code = code[:-1]
118
+ language_dict[name] = code
119
+
120
+ # Get the language names for choices in the dropdown
121
+ languages = list(language_dict.keys())
122
+ first_language = languages[0]
123
+ sorted_languages = sorted(languages[1:])
124
+ sorted_languages.insert(0, first_language)
125
+
126
+ default_language = "English"
127
+
128
+ # Prediction responses
129
+ malignant_text = "Malignant. Please consult a doctor for further evaluation."
130
+ benign_text = "Benign. Please consult a doctor for further evaluation."
131
+
132
+
133
+ # In[23]:
134
+
135
+
136
+ # Create instance
137
+ llm = ChatOpenAI(
138
+ openai_api_key=openai_api_key, model_name="gpt-3.5-turbo", temperature=0
139
+ )
140
+
141
+
142
+ # In[24]:
143
+
144
+
145
+ # Method to get system and human messages for ChatOpenAI - Predictions
146
+ def get_prediction_messages(prediction_text):
147
+ # Create a HumanMessage object
148
+ human_message = HumanMessage(content=f"skin lesion that appears {prediction_text}")
149
+
150
+ # Get the system message
151
+ system_message = SystemMessage(
152
+ content="You are a medical professional chatting with a patient. You want to provide helpful information and give a preliminary assessment."
153
+ )
154
+
155
+ # Return the system message
156
+ return [system_message, human_message]
157
+
158
+
159
+ # In[25]:
160
+
161
+
162
+ # Method to get system and human messages for ChatOpenAI - Help
163
+ def get_chat_messages(chat_prompt):
164
+ # Create a HumanMessage object
165
+ human_message = HumanMessage(content=chat_prompt)
166
+
167
+ # Get the system message
168
+ system_message = SystemMessage(
169
+ content="You are a medical professional chatting with a patient. You want to provide helpful information."
170
+ )
171
+ # Return the system message
172
+ return [system_message, human_message]
173
+
174
+
175
+ # In[26]:
176
+
177
+
178
+ # Method to predict the image
179
+ def predict_image(language, img):
180
+ try:
181
+ try:
182
+ # Process the image
183
+ img = img.resize((224, 224))
184
+ img_array = np.array(img) / 255.0
185
+ img_array = np.expand_dims(img_array, axis=0)
186
+ except Exception as e:
187
+ print(f"Error: {e}")
188
+ return "There was an error processing the image. Please try again."
189
+
190
+ # Get prediction from model
191
+ prediction = adam_5.predict(img_array)
192
+ text_prediction = "Malignant" if prediction[0][0] > 0.5 else "Benign"
193
+
194
+ try:
195
+ # Get the system and human messages
196
+ messages = get_prediction_messages(text_prediction)
197
+
198
+ # Get the response from ChatOpenAI
199
+ result = llm(messages)
200
+
201
+ # Get the text prediction
202
+ text_prediction = (
203
+ f"Prediction: {text_prediction} Explanation: {result.content}"
204
+ )
205
+
206
+ except Exception as e:
207
+ print(f"Error: {e}")
208
+ print(f"Prediction: {text_prediction}")
209
+ text_prediction = (
210
+ malignant_text if text_prediction == "Malignant" else benign_text
211
+ )
212
+
213
+ # Get selected language code
214
+ selected_code = language_dict[language]
215
+
216
+ # Check if the target and source languages are the same
217
+ if selected_code == "en_XX":
218
+ return (
219
+ text_prediction,
220
+ gr.update(visible=False),
221
+ gr.update(visible=True),
222
+ gr.update(visible=True),
223
+ gr.update(visible=True),
224
+ )
225
+
226
+ try:
227
+ # Encode, generate tokens, decode the prediction
228
+ encoded_text = tokenizer(text_prediction, return_tensors="pt")
229
+ generated_tokens = model.generate(
230
+ **encoded_text,
231
+ forced_bos_token_id=tokenizer.lang_code_to_id[selected_code],
232
+ )
233
+ result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
234
+
235
+ # Return the result
236
+ return (
237
+ result[0],
238
+ gr.update(visible=False),
239
+ gr.update(visible=True),
240
+ gr.update(visible=True),
241
+ gr.update(visible=True),
242
+ )
243
+ except Exception as e:
244
+ print(f"Error: {e}")
245
+ return (
246
+ f"""There was an error processing the translation.
247
+ In English:
248
+ {text_prediction}
249
+ """,
250
+ gr.update(visible=False),
251
+ gr.update(visible=True),
252
+ gr.update(visible=True),
253
+ gr.update(visible=True),
254
+ )
255
+
256
+ except Exception as e:
257
+ print(f"Error: {e}")
258
+ return (
259
+ "There was an error processing the request. Please try again.",
260
+ gr.update(visible=True),
261
+ gr.update(visible=False),
262
+ gr.update(visible=False),
263
+ gr.update(visible=False),
264
+ )
265
+
266
+
267
+ # In[27]:
268
+
269
+
270
+ # Method for on submit
271
+ def on_submit(language, img):
272
+ print(f"Language: {language}")
273
+ if language is None or len(language) == 0:
274
+ language = default_language
275
+ if img is None:
276
+ return (
277
+ "No image uploaded. Please try again.",
278
+ gr.update(visible=True),
279
+ gr.update(visible=False),
280
+ gr.update(visible=False),
281
+ gr.update(visible=False),
282
+ )
283
+ return predict_image(language, img)
284
+
285
+
286
+ # In[28]:
287
+
288
+
289
+ # Method for on clear
290
+ def on_clear():
291
+ return (
292
+ gr.update(),
293
+ gr.update(),
294
+ gr.update(),
295
+ gr.update(visible=True),
296
+ gr.update(value=None, visible=False),
297
+ gr.update(value=None, visible=False),
298
+ gr.update(visible=False),
299
+ )
300
+
301
+
302
+ # In[29]:
303
+
304
+
305
+ # Method for on chat
306
+ def on_chat(language, chat_prompt):
307
+ try:
308
+ # Get the system and human messages
309
+ messages = get_chat_messages(chat_prompt)
310
+ # Get the response from ChatOpenAI
311
+ result = llm(messages)
312
+ # Get the text prediction
313
+ chat_response = result.content
314
+
315
+ except Exception as e:
316
+ print(f"Error: {e}")
317
+ return gr.update(
318
+ value="There was an error processing your question. Please try again.",
319
+ visible=True,
320
+ ), gr.update(visible=False)
321
+
322
+ # Get selected language code
323
+ if language is None or len(language) == 0:
324
+ language = default_language
325
+ selected_code = language_dict[language]
326
+ # Check if the target and source languages are the same
327
+ if selected_code == "en_XX":
328
+ return gr.update(value=chat_response, visible=True), gr.update(visible=False)
329
+
330
+ try:
331
+ # Encode, generate tokens, decode the prediction
332
+ encoded_text = tokenizer(chat_response, return_tensors="pt")
333
+ generated_tokens = model.generate(
334
+ **encoded_text, forced_bos_token_id=tokenizer.lang_code_to_id[selected_code]
335
+ )
336
+ result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
337
+
338
+ # Return the result
339
+ return gr.update(value=result[0], visible=True), gr.update(visible=False)
340
+ except Exception as e:
341
+ print(f"Error: {e}")
342
+ return (
343
+ gr.update(
344
+ value=f"""There was an error processing the translation.
345
+ In English:
346
+ {chat_response}
347
+ """,
348
+ visible=True,
349
+ ),
350
+ gr.update(visible=False),
351
+ )
352
+
353
+
354
+ # In[30]:
355
+
356
+
357
+ # Gradio app
358
+
359
+ with gr.Blocks(theme=gr.themes.Default(primary_hue="green")) as demo:
360
+ intro = gr.Markdown(
361
+ """
362
+ # Welcome to Skin Lesion Image Classifier!
363
+ Select prediction language and upload image to start.
364
+ """
365
+ )
366
+ language = gr.Dropdown(
367
+ label="Response Language - Default English", choices=sorted_languages
368
+ )
369
+ img = gr.Image(image_mode="RGB", type="pil")
370
+ output = gr.Textbox(label="Results", show_copy_button=True)
371
+ chat_prompt = gr.Textbox(
372
+ label="Do you have a question about the results or skin cancer?",
373
+ placeholder="Enter your question here...",
374
+ visible=False,
375
+ )
376
+ chat_response = gr.Textbox(
377
+ label="Chat Response", visible=False, show_copy_button=True
378
+ )
379
+ submit_btn = gr.Button("Submit", variant="primary", visible=True)
380
+ chat_btn = gr.Button("Submit Question", variant="primary", visible=False)
381
+ submit_btn.click(
382
+ fn=on_submit,
383
+ inputs=[language, img],
384
+ outputs=[output, submit_btn, chat_prompt, chat_btn, chat_response],
385
+ )
386
+ chat_btn.click(
387
+ fn=on_chat, inputs=[language, chat_prompt], outputs=[chat_response, chat_btn]
388
+ )
389
+ clear_btn = gr.ClearButton(
390
+ components=[language, img, output, chat_response], variant="stop"
391
+ )
392
+ clear_btn.click(
393
+ fn=on_clear,
394
+ outputs=[
395
+ language,
396
+ img,
397
+ output,
398
+ submit_btn,
399
+ chat_prompt,
400
+ chat_response,
401
+ chat_btn,
402
+ ],
403
+ )
404
+
405
+ if __name__ == "__main__":
406
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.4.1
2
+ langchain==0.1.17
3
+ langchain-community==0.0.36
4
+ langchain-core==0.1.50
5
+ langchain-openai==0.1.6
6
+ langchain-text-splitters==0.0.1
7
+ python-dotenv==1.0.0
8
+ tensorflow==2.12.0
9
+ transformers==4.40.1