Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# coding: utf-8 | |
# In[17]: | |
import pickle | |
from PIL import Image | |
import numpy as np | |
import gradio as gr | |
from pathlib import Path | |
from transformers import pipeline | |
from tensorflow.keras.models import load_model | |
import tensorflow as tf | |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
from dotenv import load_dotenv | |
import openai | |
import os | |
from langchain.schema import HumanMessage, SystemMessage | |
from langchain_openai import ChatOpenAI | |
# Skin-lesion classifier: a Keras model saved in HDF5 format under ./models.
file_path = Path("models") / "model_adam_5.h5"
adam_5 = tf.keras.models.load_model(file_path)

# Pull variables from a local .env file (if present) into the process
# environment, then read the OpenAI key from it. Only whether the key was
# found is printed, never the key itself.
load_dotenv()
openai_api_key = os.environ.get("OPENAI_API_KEY")
print("OpenAI API Key Loaded:", openai_api_key is not None)
# MBart-50 many-to-many translator and its matching tokenizer, both loaded
# from the same Hugging Face checkpoint.
translation_checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(translation_checkpoint)
tokenizer = MBart50TokenizerFast.from_pretrained(translation_checkpoint)
# Text handed to the translator is treated as English source text.
tokenizer.src_lang = "en_XX"
# Constants
# Languages supported by MBart-50, one "Display name (mbart_code)" entry each.
# English is listed first and stays first in the dropdown.
language_info = [
    "English (en_XX)",
    "Arabic (ar_AR)",
    "Czech (cs_CZ)",
    "German (de_DE)",
    "Spanish (es_XX)",
    "Estonian (et_EE)",
    "Finnish (fi_FI)",
    "French (fr_XX)",
    "Gujarati (gu_IN)",
    "Hindi (hi_IN)",
    "Italian (it_IT)",
    "Japanese (ja_XX)",
    "Kazakh (kk_KZ)",
    "Korean (ko_KR)",
    "Lithuanian (lt_LT)",
    "Latvian (lv_LV)",
    "Burmese (my_MM)",
    "Nepali (ne_NP)",
    "Dutch (nl_XX)",
    "Romanian (ro_RO)",
    "Russian (ru_RU)",
    "Sinhala (si_LK)",
    "Turkish (tr_TR)",
    "Vietnamese (vi_VN)",
    "Chinese (zh_CN)",
    "Afrikaans (af_ZA)",
    "Azerbaijani (az_AZ)",
    "Bengali (bn_IN)",
    "Persian (fa_IR)",
    "Hebrew (he_IL)",
    "Croatian (hr_HR)",
    "Indonesian (id_ID)",
    "Georgian (ka_GE)",
    "Khmer (km_KH)",
    "Macedonian (mk_MK)",
    "Malayalam (ml_IN)",
    "Mongolian (mn_MN)",
    "Marathi (mr_IN)",
    "Polish (pl_PL)",
    "Pashto (ps_AF)",
    "Portuguese (pt_XX)",
    "Swedish (sv_SE)",
    "Swahili (sw_KE)",
    "Tamil (ta_IN)",
    "Telugu (te_IN)",
    "Thai (th_TH)",
    "Tagalog (tl_XX)",
    "Ukrainian (uk_UA)",
    "Urdu (ur_PK)",
    "Xhosa (xh_ZA)",
    "Galician (gl_ES)",
    "Slovene (sl_SI)",
]


def _split_language_entry(entry):
    """Split a ``"Name (code)"`` entry into a ``(name, code)`` pair.

    >>> _split_language_entry("English (en_XX)")
    ('English', 'en_XX')
    """
    name, _, code = entry.partition(" (")
    return name, code[:-1]  # drop the closing parenthesis


# Display name -> MBart language code, in the same order as ``language_info``.
language_dict = dict(map(_split_language_entry, language_info))
# Dropdown choices: keep the first entry (English) on top, sort the rest A-Z.
languages = list(language_dict)
first_language, *other_languages = languages
sorted_languages = [first_language, *sorted(other_languages)]
default_language = "English"
# Fallback prediction messages, shown when the chat model cannot produce an
# explanation for the classifier's verdict.
_CONSULT_ADVICE = "Please consult a doctor for further evaluation."
malignant_text = f"Malignant. {_CONSULT_ADVICE}"
benign_text = f"Benign. {_CONSULT_ADVICE}"
# Shared chat client used for both prediction explanations and follow-up
# questions; temperature=0 requests the least random sampling.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo",
    temperature=0,
    openai_api_key=openai_api_key,
)
# Chat messages that ask the model to explain a classifier verdict.
def get_prediction_messages(prediction_text):
    """Build the chat prompt explaining a classifier verdict.

    Args:
        prediction_text: The label to explain (the caller passes
            "Malignant" or "Benign").

    Returns:
        A two-item list: the medical-assistant system message, followed by
        the user message describing the lesion.
    """
    persona = (
        "You are a medical professional chatting with a patient. "
        "You want to provide helpful information and give a preliminary assessment."
    )
    return [
        SystemMessage(content=persona),
        HumanMessage(content=f"skin lesion that appears {prediction_text}"),
    ]
# Chat messages for a free-form follow-up question from the user.
def get_chat_messages(chat_prompt):
    """Build the message list for a follow-up question.

    Args:
        chat_prompt: The user's question, passed through unchanged.

    Returns:
        ``[SystemMessage, HumanMessage]``: the medical-assistant persona
        followed by the user's question.
    """
    persona = SystemMessage(
        content="You are a medical professional chatting with a patient. You want to provide helpful information."
    )
    question = HumanMessage(content=chat_prompt)
    return [persona, question]
# Method to predict the image
def _prediction_outputs(text, *, succeeded):
    """Build the 5 values for ``[output, submit_btn, chat_prompt, chat_btn, chat_response]``.

    On success the Submit button is hidden and the chat controls are shown;
    on failure Submit stays visible (so the user can retry) and chat stays hidden.
    """
    return (
        text,
        gr.update(visible=not succeeded),
        gr.update(visible=succeeded),
        gr.update(visible=succeeded),
        gr.update(visible=succeeded),
    )


def _preprocess_image(img):
    """Turn a PIL image into a batch of one 224x224 RGB array scaled to [0, 1]."""
    # Force 3 channels so RGBA/grayscale inputs don't break the classifier.
    resized = img.convert("RGB").resize((224, 224))
    return np.expand_dims(np.array(resized) / 255.0, axis=0)


def _explain_prediction(label):
    """Ask the chat model to explain ``label``; fall back to a canned message on failure."""
    try:
        result = llm.invoke(get_prediction_messages(label))
    except Exception as e:  # best effort: the app still works without the LLM
        print(f"Error: {e}")
        print(f"Prediction: {label}")
        return malignant_text if label == "Malignant" else benign_text
    return f"Prediction: {label} Explanation: {result.content}"


def _translate_prediction(text, selected_code):
    """Translate English ``text`` into ``selected_code``; on failure return English with a notice."""
    try:
        encoded_text = tokenizer(text, return_tensors="pt")
        # NOTE(review): generate() uses the checkpoint's default generation
        # length, so long explanations may come back truncated; confirm.
        generated_tokens = model.generate(
            **encoded_text,
            forced_bos_token_id=tokenizer.lang_code_to_id[selected_code],
        )
        return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error: {e}")
        return f"There was an error processing the translation.\nIn English:\n{text}\n"


def predict_image(language, img):
    """Classify a skin-lesion image and return the five UI output updates.

    Called from ``on_submit``, whose result feeds ``submit_btn.click`` with
    ``outputs=[output, submit_btn, chat_prompt, chat_btn, chat_response]``;
    every path therefore returns a 5-tuple in that order.

    Args:
        language: Display name of the response language (a key of
            ``language_dict``); an unknown name yields the generic error.
        img: PIL image from the upload widget.

    Returns:
        ``(text, submit_btn, chat_prompt, chat_btn, chat_response)`` updates.
        On success ``text`` is the (possibly translated) prediction and the
        chat controls are revealed. On failure ``text`` is an error message
        and the Submit button stays visible.
    """
    try:
        img_array = _preprocess_image(img)
    except Exception as e:
        print(f"Error: {e}")
        # Previously a bare string was returned here, which does not match
        # the 5 outputs the click handler is wired to.
        return _prediction_outputs(
            "There was an error processing the image. Please try again.",
            succeeded=False,
        )
    try:
        prediction = adam_5.predict(img_array)
        # assumes a single sigmoid-style score where > 0.5 means malignant — TODO confirm
        label = "Malignant" if prediction[0][0] > 0.5 else "Benign"
        text_prediction = _explain_prediction(label)
        selected_code = language_dict[language]
    except Exception as e:
        print(f"Error: {e}")
        return _prediction_outputs(
            "There was an error processing the request. Please try again.",
            succeeded=False,
        )
    # Explanations are produced in English; translate only for other targets.
    if selected_code != "en_XX":
        text_prediction = _translate_prediction(text_prediction, selected_code)
    return _prediction_outputs(text_prediction, succeeded=True)
# Submit-button handler: validate the inputs, then run the classifier.
def on_submit(language, img):
    """Handle the Submit button.

    Args:
        language: Dropdown selection; empty or ``None`` falls back to
            ``default_language``.
        img: Uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        Five updates for ``[output, submit_btn, chat_prompt, chat_btn,
        chat_response]``.
    """
    print(f"Language: {language}")
    chosen_language = language if language else default_language
    if img is not None:
        return predict_image(chosen_language, img)
    # Nothing uploaded: show a hint, keep Submit available, keep chat hidden.
    return (
        "No image uploaded. Please try again.",
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )
# Clear-button handler: restore the initial layout.
def on_clear():
    """Reset the UI after Clear is pressed.

    Returns seven updates for ``[language, img, output, submit_btn,
    chat_prompt, chat_response, chat_btn]``:
    - The first three are left untouched; the ClearButton already empties them.
    - Submit is shown again.
    - The chat textboxes are emptied and hidden.
    - The chat button is hidden.
    """
    leave_as_is = (gr.update(), gr.update(), gr.update())
    show_submit = gr.update(visible=True)
    empty_and_hide = (
        gr.update(value=None, visible=False),
        gr.update(value=None, visible=False),
    )
    hide_chat_button = gr.update(visible=False)
    return leave_as_is + (show_submit,) + empty_and_hide + (hide_chat_button,)
# Method for on chat
def on_chat(language, chat_prompt):
    """Answer a follow-up question with the chat model, in the chosen language.

    Wired to ``chat_btn.click`` with ``outputs=[chat_response, chat_btn]``.

    Args:
        language: Dropdown selection; empty or ``None`` falls back to
            ``default_language``. Must otherwise be a key of ``language_dict``.
        chat_prompt: The user's question text.

    Returns:
        ``(chat_response update, chat_btn update)``.
        - Answer, or English fallback when translation fails: shown, and the
          button is hidden.
        - Empty question or chat-model failure: a message is shown and the
          button stays visible so the user can try again.
    """
    if not (chat_prompt or "").strip():
        # Don't spend an API call on an empty question.
        return (
            gr.update(value="Please enter a question.", visible=True),
            gr.update(visible=True),
        )
    try:
        chat_response = llm.invoke(get_chat_messages(chat_prompt)).content
    except Exception as e:
        print(f"Error: {e}")
        # Keep the button visible: the message asks the user to try again.
        return (
            gr.update(
                value="There was an error processing your question. Please try again.",
                visible=True,
            ),
            gr.update(visible=True),
        )
    if not language:
        language = default_language
    selected_code = language_dict[language]
    # NOTE(review): a successful answer hides chat_btn, allowing one question
    # per prediction — confirm this is intended.
    if selected_code == "en_XX":
        return gr.update(value=chat_response, visible=True), gr.update(visible=False)
    try:
        encoded_text = tokenizer(chat_response, return_tensors="pt")
        generated_tokens = model.generate(
            **encoded_text, forced_bos_token_id=tokenizer.lang_code_to_id[selected_code]
        )
        translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    except Exception as e:
        print(f"Error: {e}")
        fallback = (
            "There was an error processing the translation.\n"
            f"In English:\n{chat_response}\n"
        )
        return gr.update(value=fallback, visible=True), gr.update(visible=False)
    return gr.update(value=translated, visible=True), gr.update(visible=False)
# Gradio app. Components are laid out in creation order, so keep this order.
with gr.Blocks(theme=gr.themes.Default(primary_hue="green")) as demo:
    # Header text.
    intro = gr.Markdown(
        """
# Welcome to Skin Lesion Image Classifier!
Select prediction language and upload image to start.
"""
    )
    # Inputs: optional response language and the lesion photo (as a PIL image).
    language = gr.Dropdown(
        choices=sorted_languages,
        label="Response Language - Default English",
    )
    img = gr.Image(type="pil", image_mode="RGB")
    output = gr.Textbox(show_copy_button=True, label="Results")
    # Follow-up chat widgets start hidden; a successful prediction reveals them.
    chat_prompt = gr.Textbox(
        visible=False,
        placeholder="Enter your question here...",
        label="Do you have a question about the results or skin cancer?",
    )
    chat_response = gr.Textbox(
        visible=False, show_copy_button=True, label="Chat Response"
    )
    submit_btn = gr.Button("Submit", visible=True, variant="primary")
    chat_btn = gr.Button("Submit Question", visible=False, variant="primary")

    # Event wiring: each outputs list matches the order of the handler's return tuple.
    submit_outputs = [output, submit_btn, chat_prompt, chat_btn, chat_response]
    submit_btn.click(fn=on_submit, inputs=[language, img], outputs=submit_outputs)
    chat_outputs = [chat_response, chat_btn]
    chat_btn.click(fn=on_chat, inputs=[language, chat_prompt], outputs=chat_outputs)

    # The ClearButton empties these components; on_clear then restores visibility.
    clear_btn = gr.ClearButton(
        variant="stop", components=[language, img, output, chat_response]
    )
    clear_outputs = [
        language,
        img,
        output,
        submit_btn,
        chat_prompt,
        chat_response,
        chat_btn,
    ]
    clear_btn.click(fn=on_clear, outputs=clear_outputs)


if __name__ == "__main__":
    demo.launch(share=True)