import torchvision.transforms as transforms | |
import matplotlib.pyplot as plt | |
import torch | |
import numpy as np | |
from PIL import Image | |
from collections import OrderedDict | |
from Transformer_Explainability.modules.layers_ours import * | |
from Transformer_Explainability.baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP | |
from Transformer_Explainability.baselines.ViT.ViT_explanation_generator import LRP | |
from Transformer_Explainability.XAI import * | |
from io import BytesIO | |
from PIL import Image | |
import streamlit as st | |
import pandas as pd | |
from streamlit_option_menu import option_menu | |
# import Pages as pg | |
# mean = [0.485, 0.456, 0.406] # our images have three channels | |
# std = [0.229, 0.224, 0.225] | |
# normalize = transforms.Normalize(mean=mean, std=std) | |
# transform = transforms.Compose([ | |
# transforms.Resize((224,224)), | |
# transforms.ToTensor(), | |
# normalize, | |
# ]) | |
# resizing = transforms.Resize((224,224)) | |
# Model output index -> human-readable class label.
# Index 0 ("Benign") is the only benign class; every other label is a
# carcinoma subtype and is reported as "Malignant" in the conclusion table.
CLS2IDX = dict(enumerate([
    "Benign",
    "Ductal carcinoma",
    "Lobular carcinoma",
    "Mucinous carcinoma",
    "Papillary carcinoma",
]))
# Page-wide settings, issued before any other st.* call in this script.
st.set_page_config(
    layout="wide",
    page_title="Breast Cancer Diagnosis",
    initial_sidebar_state="expanded",  # alternative: "collapsed"
    # page_icon="👨‍⚕️",
)
# "st.session_state :", st.session_state | |
# Per-session counter that survives Streamlit reruns: seeded with 1 on the
# session's first run, then incremented after each diagnosis run.
if "loop" not in st.session_state:
    st.session_state.loop = 1
def _load_model(mag):
    """Return the LRP-capable ViT fine-tuned for magnification ``mag``, in eval mode.

    The checkpoint is read from ``./Model/V<mag>.pth`` (relative to the current
    working directory) and mapped onto the CPU.
    """
    num_classes = len(CLS2IDX)
    # pretrained=False: the strict load_state_dict below must overwrite every
    # parameter for it to succeed, so downloading ImageNet weights first is wasted work.
    model = vit_LRP(pretrained=False)
    # Replace the classifier head with one sized to this app's classes
    # (`Linear` presumably comes from the layers_ours star-import).
    model.head = Linear(model.head.in_features, num_classes)
    state_dict = torch.load("./Model/V" + mag + ".pth", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _open_rgb(file):
    """Decode an uploaded image file into an RGB PIL image."""
    # getvalue() returns the whole upload regardless of the stream position,
    # whereas read() yields b"" once the buffer has already been consumed.
    # convert("RGB") turns RGBA / palette / grayscale PNGs into 3 channels,
    # which the transform and model are assumed to require — TODO confirm against `tf`.
    return Image.open(BytesIO(file.getvalue())).convert("RGB")


def use_model(upl, mag):
    """Diagnose every uploaded image and render the results on the page.

    For each file this shows the resized original image, an explainability map
    from ``generate_visualization`` and a table of class probabilities ranked
    highest first. It then shows a "Conclusion" table with one row per file.

    Args:
        upl: Uploaded files (objects exposing ``getvalue()`` and ``name``, e.g.
            the list returned by ``st.file_uploader``).
        mag: Magnification label such as ``"40X"``; selects the checkpoint file.
    """
    num_classes = len(CLS2IDX)
    model = _load_model(mag)
    ag = LRP(model)
    model_pred = []
    with st.container():
        st.subheader("Result")
        for file in upl:
            img = _open_rgb(file)
            inputs = tf(img)
            outputs = model(inputs.unsqueeze(0))
            prob = torch.softmax(outputs, dim=1)
            # Rank by logits; softmax is monotonic, so the order matches `prob`.
            class_indices = outputs.detach().topk(num_classes, dim=1).indices[0].tolist()
            model_prop = [
                [CLS2IDX[idx], "{:.2f}".format(prob[0, idx] * 100)]
                for idx in class_indices
            ]
            img_col1, img_col2, txt_col = st.columns((1, 1, 2))
            with img_col1:
                st.image(
                    transforms.Resize((224, 224))(img),
                    output_format="auto",
                    caption="original image",
                )
            with img_col2:
                st.image(
                    generate_visualization(inputs, ag),
                    caption="explainability",
                )
            with txt_col:
                df = pd.DataFrame(model_prop, columns=["Class", "Probability"])
                st.dataframe(df, width=600)
            top_label = CLS2IDX[class_indices[0]]
            if top_label == "Benign":
                model_pred.append([file.name, "Benign", "-"])
            else:
                # Every carcinoma subtype is reported as malignant, with the subtype as Type.
                model_pred.append([file.name, "Malignant", top_label])
    with st.container():
        st.subheader("Conclusion")
        df2 = pd.DataFrame(model_pred, columns=["Name", "Result", "Type"])
        with st.container():
            st.dataframe(df2, width=800)
# --Header--
# Page title and a short introduction with a link describing the expected input data.
with st.container():
    st.title("Breast cancer diagnosis by transformer model")
    st.subheader("Hi :wave:")
    st.write(
        "This website helps you diagnose breast cancer images collected "
        "by the surgical open biopsy (SOB) method."
    )
    st.write(
        "The images you upload must match the format of the BreakHis dataset "
        "[Learn more >](https://web.inf.ufpr.br/vri/databases/breast-cancer-histopathological-database-breakhis/)"
    )
# --Sidebar--
# Collects the two inputs the diagnosis needs: the magnification (which picks
# the checkpoint) and the images themselves. Both names are read further down.
with st.sidebar:
    # --Inputs--
    with st.container():
        st.subheader("Inputs")
        # "--Select--" is a sentinel meaning "nothing chosen yet".
        maglist = ["--Select--", "40X", "100X", "200X", "400X"]
        # Plain st.selectbox: inside `with st.sidebar` / `st.container()` it is
        # placed in this Inputs container. st.sidebar.selectbox would bypass it.
        magnification = st.selectbox("Select the Magnification", maglist)
    with st.container():
        uploaded = st.file_uploader(
            "Choose images to diagnose",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
        )
# with st.container(): | |
# diagnose = st.button("Diagnose") | |
with st.container():
    # Run only once both inputs are provided. Truthiness covers an empty list
    # and, defensively, a None result from the uploader.
    if uploaded and magnification != "--Select--":
        use_model(uploaded, magnification)
        st.session_state["loop"] += 1
# "st.session_state obj :", st.session_state | |
# pass | |
# selected = option_menu( | |
# menu_title=None, # required | |
# options=["Home", "How to use", "Contact"] | |
# ) | |
# if selected == "Home": | |
# inputs() | |
# elif selected == "How to use": | |
# pg.howtouse() | |
# elif selected == "Contact": | |
# pg.contact() | |
# st.set_page_config( | |
# page_title="Breast Cancer Diagnosis", | |
# page_icon="👨⚕️", | |
# layout="wide", | |
# initial_sidebar_state="expanded", #collapsed | |
# ) | |
# st.title("Breast Cancer Diagnosis BY Transformer Model") | |
# st.sidebar.subheader("Input") | |
# "st.session_state obj :", st.session_state | |
# if 'boolean' not in st.session_state: | |
# st.session_state['boolean'] = False | |
# "st.session_state obj :", st.session_state | |
# models_list = ["--Select--", "40X", "100X", "200X", "400X"] | |
# magnitude = st.sidebar.selectbox("Select the Magnification", models_list) | |
# uploaded_file = st.sidebar.file_uploader( | |
# "Choose images to diagnose", type=["jpg", "jpeg", "png"], accept_multiple_files = True | |
# ) | |
# diagnosis = st.sidebar.button("Diagnose") | |
# if diagnosis: | |
# st.session_state['boolean'] = True | |
# num_classes = 5 | |
# model = vit_LRP(pretrained=True) | |
# num_ftrs = model.head.in_features | |
# model.head = Linear(num_ftrs, num_classes) | |
# if uploaded_file != [] and st.session_state['boolean'] and magnitude != "--Select--": | |
# state_dict = torch.load("./Model/V" + magnitude + ".pth", map_location=torch.device('cpu')) | |
# model.load_state_dict(state_dict) | |
# model.eval() | |
# ag = LRP(model) | |
# predictions = [] | |
# xai = [] | |
# for each_file in uploaded_file: | |
# bytes_data = each_file.read() | |
# img = Image.open(BytesIO(bytes_data)) | |
# inputs = transform(img) | |
# outputs = model(inputs.unsqueeze(0)) | |
# model_predicted = CLS2IDX[print_top_classes(outputs)[0]] | |
# if model_predicted == "benign": | |
# type = '-' | |
# else: | |
# type = model_predicted | |
# model_predicted = "malignant" | |
# xai.append(generate_visualization(inputs, ag)) | |
# predictions.append([img, each_file.name, model_predicted, type]) | |
# #st.write(each_file) | |
# #st.write(each_file.name) | |
# # st.image(img) | |
# #st.write(model_predicted) | |
# def showim(): | |
# df = pd.DataFrame( | |
# predictions, columns=["Image", "Name", "Result", "Type"] | |
# ) | |
# st.dataframe( | |
# df.iloc[:,1:], width=1000 | |
# ) | |
# idx = st.select_slider("Select the index of picture", df.index, on_change=showim) | |
# st.write(idx) | |
# # st.image(predictions[index][0]) | |
# st.image(xai[idx]) | |
# showim() | |
# print(uploaded_file, diagnosis, magnitude) |