# MRI-Brain-Tumor / app.py
# Hugging Face Space by TharunSiva (commit 2865872: "util files", verified)
import gradio as gr
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow.keras.backend as K
from keras.preprocessing import image
from ResUNet import *
from eff import *
from vit import *
from eff_b3 import *
# Preprocessing for the PyTorch classifiers (EfficientNet-V2 and ViT in
# `classification`): convert the input to a tensor, then resize it to 224x224.
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((224, 224))])
# Example rows for the classification tab, shaped [image path, model name] to match
# the Interface's two inputs: two sample images per model.
examples1 = [
    [f"examples/Classification/{index}.jpg", model_name]
    for model_name, indices in (
        ("EfficientNet-B3", (0, 3)),
        ("EfficientNet-V2", (1, 4)),
        ("ViT", (2, 5)),
    )
    for index in indices
]
# def classification(image):
# input_tensor = transform(image).unsqueeze(0).to(CFG.DEVICE) # Add batch dimension
# input_batch = input_tensor
# # Perform inference
# with torch.no_grad():
# output1 = efficientnet_model(input_batch).to(CFG.DEVICE)
# output2 = vit_model(input_batch).to(CFG.DEVICE)
# b3_img = cv2.resize(image, (256, 256))
# b3_img = np.reshape(b3_img, (1, 256, 256, 3))
# output3 = b3_model.predict(b3_img)
# # You can now use the 'output' tensor as needed (e.g., get predictions)
# # print(output)
# res1 = torch.softmax(output1, dim=1)
# res2 = torch.softmax(output2, dim=1)
# res3 = tf.nn.softmax(output3)
# probs1 = {class_names[i]: float(res1[0][i]) for i in range(len(class_names))}
# probs2 = {class_names[i]: float(res2[0][i]) for i in range(len(class_names))}
# probs3 = {class_names[i]: float(res3[0][i]) for i in range(len(class_names))}
# return probs3, probs2, probs1
# def classification(image, model="EfficientNet-B3"):
# input_tensor = transform(image).unsqueeze(0).to(CFG.DEVICE) # Add batch dimension
# input_batch = input_tensor
# if(model == "EfficientNet-B3"):
# b3_img = cv2.resize(image, (256, 256))
# b3_img = np.reshape(b3_img, (1, 256, 256, 3))
# output3 = b3_model.predict(b3_img)
# res3 = tf.nn.softmax(output3)
# probs3 = {class_names[i]: float(res3[0][i]) for i in range(len(class_names))}
# return probs3
# elif(model == "EfficientNet-V2"):
# with torch.no_grad():
# output1 = efficientnet_model(input_batch).to(CFG.DEVICE)
# res1 = torch.softmax(output1, dim=1)
# probs1 = {class_names[i]: float(res1[0][i]) for i in range(len(class_names))}
# return probs1
# else:
# with torch.no_grad():
# output2 = vit_model(input_batch).to(CFG.DEVICE)
# res2 = torch.softmax(output2, dim=1)
# probs2 = {class_names[i]: float(res2[0][i]) for i in range(len(class_names))}
# return probs2
def classification(image, model="EfficientNet-B3"):
    """Classify an MRI image with the selected model.

    Args:
        image: the image delivered by ``gr.Image`` (assumed to be a NumPy array with
            Gradio's default settings), or ``None`` when nothing was uploaded.
        model: ``"ViT"``, ``"EfficientNet-V2"`` or ``"EfficientNet-B3"``. Any other
            value falls back to EfficientNet-B3, as before.

    Returns:
        dict mapping each entry of ``class_names`` to its softmax probability, in the
        form ``gr.Label`` expects.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    if model in ("ViT", "EfficientNet-V2"):
        torch_model = vit_model if model == "ViT" else efficientnet_model
        # Only the PyTorch models need the torchvision transform; add a batch axis
        # and move the input to the configured device.
        input_batch = transform(image).unsqueeze(0).to(CFG.DEVICE)
        with torch.no_grad():
            # The output already lives on the model's device; no extra .to() needed.
            logits = torch_model(input_batch)
        probs = torch.softmax(logits, dim=1)[0]
    else:
        # The Keras EfficientNet-B3 path is fed a (1, 256, 256, 3) batch.
        b3_img = cv2.resize(image, (256, 256))
        b3_img = np.reshape(b3_img, (1, 256, 256, 3))
        logits = b3_model.predict(b3_img)
        probs = tf.nn.softmax(logits)[0]

    return {name: float(probs[i]) for i, name in enumerate(class_names)}
# Classification tab: an image and a model selector in, a top-3 label out.
# cache_examples=True asks Gradio to precompute the outputs for `examples1`.
classify = gr.Interface(
    classification,
    [
        gr.Image(label="Image"),
        gr.Radio(["EfficientNet-B3", "EfficientNet-V2", "ViT"], value="EfficientNet-B3"),
    ],
    [gr.Label(num_top_classes=3, label="Result")],
    examples=examples1,
    cache_examples=True,
)
# ---------------------------------------------------------
# Segmentation model: build the network via `load_model` (presumably provided by
# `from ResUNet import *`) and restore trained weights from the HDF5 file, which is
# resolved relative to the working directory.
seg_model = load_model()
seg_model.load_weights("ResUNet-segModel-weights.hdf5")

# Example inputs for the detection tab: examples/ResUNet/0.jpg ... 4.jpg.
examples2 = [f"examples/ResUNet/{n}.jpg" for n in range(5)]
def _standardize_for_resunet(img):
    """Resize *img* to 256x256 and z-score it into a (1, 256, 256, 3) float64 batch."""
    scaled = cv2.resize(img * 1.0 / 255.0, (256, 256)).astype(np.float64)
    batch = np.reshape(scaled, (1, 256, 256, 3))
    batch -= batch.mean()
    std = batch.std()
    # A constant image has zero spread; skip the division rather than produce NaNs.
    if std > 0:
        batch /= std
    return batch


def detection(img):
    """Predict a tumour mask for *img* and return a PNG of the masked image.

    Args:
        img: the image delivered by ``gr.Image`` (assumed to be an RGB uint8 array of
            shape (H, W, 3) with Gradio's default settings), or ``None``.

    Returns:
        ``gr.update`` pointing the output image at a PNG of the input resized to
        256x256, with predicted mask pixels painted (0, 255, 150).

    Raises:
        gr.Error: if no image was provided.
    """
    import tempfile

    if img is None:
        raise gr.Error("Please upload an image first.")

    prediction = seg_model.predict(_standardize_for_resunet(img))
    # Round the mask to 0/1 (assumes sigmoid outputs in [0, 1] — confirm) and drop
    # singleton axes so it can index the 256x256 image.
    mask = np.array(prediction[0]).squeeze().round()

    # gr.Image already supplies RGB, so no BGR->RGB swap is applied; cv2.resize
    # returns a new array, so painting it does not modify the caller's image.
    overlay = cv2.resize(img, (256, 256))
    overlay[mask == 1] = (0, 255, 150)

    # Draw on a private figure that is always closed, instead of the shared global
    # pyplot figure, so artists and memory do not pile up across requests.
    fig, ax = plt.subplots()
    try:
        ax.imshow(overlay)
        ax.axis("off")
        # A unique file per call: a fixed "plot.png" let concurrent requests
        # overwrite each other's result.
        # NOTE(review): the temp file is not removed here; presumably Gradio copies
        # returned files into its own cache.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image_path = tmp.name
        fig.savefig(image_path)
    finally:
        plt.close(fig)
    return gr.update(value=image_path, visible=True)
# Detection tab: an uploaded image in, the mask overlay rendered by `detection` out.
# Outputs for `examples2` are cached by Gradio (cache_examples=True).
detect = gr.Interface(
    detection,
    [gr.Image(label="Image")],
    [gr.Image(label="Output")],
    examples=examples2,
    cache_examples=True,
)
# ##########################################
def data_viewer(label="Pituitary", count=10):
    """Return the gallery image paths for a dataset category.

    Args:
        label: folder name under ``Images/``. ``"Segmentation"`` holds
            ``original_image_{i}.png`` / ``image_with_mask_{i}.png`` pairs; every
            other category holds ``{i}.jpg`` files. A missing or empty label (e.g.
            no dropdown selection) yields an empty gallery instead of bogus paths.
        count: number of images requested. Coerced to ``int`` because a slider
            value may arrive as a float, which ``range`` rejects. For
            ``"Segmentation"``, ``count // 2 + 1`` pairs are returned.

    Returns:
        list of relative image paths in display order.
    """
    if not label:
        return []
    count = int(count)
    if label == "Segmentation":
        paths = []
        for i in range(count // 2 + 1):
            paths.append(f"Images/{label}/original_image_{i}.png")
            paths.append(f"Images/{label}/image_with_mask_{i}.png")
        return paths
    return [f"Images/{label}/{i}.jpg" for i in range(count)]
# Visualization tab: choose a dataset category and how many samples to show; the
# paths returned by `data_viewer` are rendered in a two-column gallery.
view_data = gr.Interface(
    fn=data_viewer,
    inputs=[
        # Start on the same category `data_viewer` defaults to, so submitting without
        # touching the dropdown does not pass None as the category.
        gr.Dropdown(
            ["Glioma", "Meningioma", "Pituitary", "Segmentation"],
            value="Pituitary",
            label="Category",
        ),
        gr.Slider(0, 12, value=4, step=2),
    ],
    outputs=[
        gr.Gallery(columns=2),
    ],
)
# ##########################
from huggingface_hub import InferenceClient

# Remote inference client for the Mixtral-8x7B-Instruct model behind the ChatBot tab.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def format_prompt(message, history):
    """Render the chat history plus the new message as a Mixtral instruct prompt.

    Each past (user, bot) pair becomes ``[INST] user [/INST] bot</s> ``; the whole
    prompt starts with ``<s>`` and ends with the pending ``[INST] message [/INST]``.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> " for user_turn, bot_turn in history
    )
    return "<s>" + past_turns + f"[INST] {message} [/INST]"
def generate(
    prompt, history, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a reply from the remote Mixtral client for *prompt* given *history*.

    Args:
        prompt: the new user message.
        history: previous (user, bot) turns, as passed by ``gr.ChatInterface``.
        temperature: sampling temperature; values below 0.01 are raised to 0.01.
        max_new_tokens: cap on generated tokens.
        top_p: nucleus-sampling threshold.
        repetition_penalty: repetition penalty forwarded to the endpoint.

    Yields:
        The reply accumulated so far, so the chat UI can render it incrementally.
    """
    # Keep temperature strictly positive (presumably the endpoint rejects 0).
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""
    for response in stream:
        token = response.token
        # Skip special tokens such as the end-of-sequence "</s>", which would
        # otherwise show up in the reply and be fed back into later prompts.
        if getattr(token, "special", False):
            continue
        output += token.text
        yield output
    return output
# Chat display for the ChatBot tab: user avatar first, bot avatar second, no label,
# with copy buttons and likeable messages.
# NOTE(review): `bubble_full_width` / `likeable` are deprecated in newer Gradio
# releases — confirm against the pinned version.
mychatbot = gr.Chatbot(
    avatar_images=["Chatbot/user.png", "Chatbot/botm.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)
# ChatBot tab: streams replies from `generate` into `mychatbot`, offering a few
# starter questions about brain tumors.
chatbot = gr.ChatInterface(
    generate,
    chatbot=mychatbot,
    examples=[
        "What is Brain Tumor and its types?",
        "What is a tumor's grade? What does this mean?",
        "What are some of the treatment options for Brain Tumor?",
        "What causes brain tumors?",
        "If I have a brain tumor, can I pass it on to my children?",
    ],
)
# Assemble the four tabs in display order and start the web server.
demo = gr.TabbedInterface(
    [classify, detect, view_data, chatbot],
    tab_names=["Classification", "Detection", "Visualization", "ChatBot"],
)
demo.launch()