import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import numpy as np
from PIL import Image
from collections import OrderedDict
from Transformer_Explainability.modules.layers_ours import *
from Transformer_Explainability.baselines.ViT.ViT_LRP import vit_base_patch16_224 as vit_LRP
from Transformer_Explainability.baselines.ViT.ViT_explanation_generator import LRP
from Transformer_Explainability.XAI import *
from io import BytesIO
from PIL import Image
import streamlit as st
import pandas as pd
from streamlit_option_menu import option_menu
# import Pages as pg

# NOTE: keep the import order above. The star imports can rebind names, and the
# second `from PIL import Image` guarantees `Image` is PIL's afterwards.
# `tf` (image preprocessing), `generate_visualization` and `Linear` are not
# defined in this file; they presumably come from the star imports
# (Transformer_Explainability.XAI / modules.layers_ours) -- TODO confirm.

# Class index -> human-readable label for the classifier head.
CLS2IDX = {
    0: "Benign",
    1: "Ductal carcinoma",
    2: "Lobular carcinoma",
    3: "Mucinous carcinoma",
    4: "Papillary carcinoma",
}
NUM_CLASSES = len(CLS2IDX)
BENIGN_LABEL = "Benign"
# Placeholder entry of the magnification selectbox; means "nothing chosen yet".
MAG_PLACEHOLDER = "--Select--"

# Must be the first Streamlit command of the script.
st.set_page_config(
    page_title="Breast Cancer Diagnosis",
    # page_icon="👨‍⚕️",
    layout="wide",
    initial_sidebar_state="expanded",  # alternative: "collapsed"
)

# Per-session run counter; incremented each time a diagnosis is rendered.
# It is not read anywhere else in this file.
if "loop" not in st.session_state:
    st.session_state["loop"] = 1

# st.cache_resource (newer Streamlit releases) keeps an object alive across
# reruns. On versions without it, fall back to no caching (original behavior).
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_state_dict(mag):
    """Load the fine-tuned weights for magnification *mag* from ./Model/V<mag>.pth.

    Cached because Streamlit re-executes the whole script on every widget
    interaction. The dict is only read: load_state_dict copies its values into
    the model's own parameters, so sharing it between runs is safe.
    """
    return torch.load("./Model/V" + mag + ".pth", map_location=torch.device("cpu"))


def _build_model(mag):
    """Return a fresh ViT-LRP model in eval mode, loaded with the *mag* weights.

    Only the weights are cached, not the model. LRP presumably records
    per-forward state on the model's modules, and a per-call model keeps that
    state from being shared between concurrent Streamlit sessions.
    """
    # No ImageNet download: load_state_dict (strict by default) overwrites
    # every parameter below anyway.
    model = vit_LRP(pretrained=False)
    # `Linear` comes from a star import (presumably layers_ours' LRP-aware
    # layer) rather than torch.nn, so relevance can propagate through the head.
    model.head = Linear(model.head.in_features, NUM_CLASSES)
    model.load_state_dict(_load_state_dict(mag))
    model.eval()
    return model


def _open_image(uploaded_file):
    """Decode a Streamlit UploadedFile into an RGB PIL image.

    getvalue() returns the whole payload regardless of the buffer's read
    position, unlike read(). The uploader accepts PNGs, which may be RGBA or
    grayscale; convert("RGB") gives the 3-channel input the ViT presumably
    expects.
    """
    return Image.open(BytesIO(uploaded_file.getvalue())).convert("RGB")


def _to_result(label):
    """Map a predicted class label to the (Result, Type) pair of the conclusion table."""
    if label == BENIGN_LABEL:
        return BENIGN_LABEL, "-"
    return "Malignant", label


def use_model(upl, mag):
    """Diagnose each uploaded image with the model for magnification *mag* and render results.

    Args:
        upl: list of Streamlit UploadedFile objects (jpg/jpeg/png) from the sidebar uploader.
        mag: magnification label ("40X", "100X", "200X" or "400X"); selects ./Model/V<mag>.pth.

    For each image, this renders three things side by side:
    - the original image resized to 224x224,
    - an LRP explainability map,
    - a table of all classes with their softmax probability (in percent).

    It then renders a "Conclusion" table with one (Name, Result, Type) row per image.
    """
    model = _build_model(mag)
    ag = LRP(model)
    model_pred = []

    with st.container():
        st.subheader("Result")
        for file in upl:
            img = _open_image(file)
            inputs = tf(img)
            # Deliberately NOT wrapped in torch.no_grad(). ViT_LRP presumably
            # registers gradient hooks during forward, which fails on tensors
            # that do not require grad.
            outputs = model(inputs.unsqueeze(0))
            prob = torch.softmax(outputs, dim=1)
            # All class indices, ordered from highest to lowest score.
            class_indices = outputs.detach().topk(NUM_CLASSES, dim=1)[1][0].tolist()
            model_prop = [
                [CLS2IDX[cls_idx], "{:.2f}".format((prob[0, cls_idx] * 100).item())]
                for cls_idx in class_indices
            ]

            img_col1, img_col2, txt_col = st.columns((1, 1, 2))
            with img_col1:
                st.image(
                    transforms.Resize((224, 224))(img),
                    output_format="auto",
                    caption="original image",
                )
            with img_col2:
                st.image(generate_visualization(inputs, ag), caption="explainability")
            with txt_col:
                df = pd.DataFrame(model_prop, columns=["Class", "Probability"])
                st.dataframe(df, width=600)

            rslt, tp = _to_result(CLS2IDX[class_indices[0]])
            model_pred.append([file.name, rslt, tp])

    with st.container():
        st.subheader("Conclusion")
        df2 = pd.DataFrame(model_pred, columns=["Name", "Result", "Type"])
        with st.container():
            st.dataframe(df2, width=800)


# --Header--
with st.container():
    st.title("Breast cancer diagnosis by transformer model")
    st.subheader("Hi :wave:")
    st.write(
        "This website helps you diagnose breast cancer from images collected "
        "by the surgical open biopsy (SOB) method."
    )
    st.write(
        "The images you upload must match the BREAKHIS dataset "
        "[Learn more >](https://docs.google.com/document/d/12NYTJkh2yKdR75XQhD3OwQoYFFwqIK7mkSVbd0-VWok/edit?usp=drivesdk)"
    )

# --Sidebar--
with st.sidebar:
    # --Inputs--
    with st.container():
        st.subheader("Inputs")
        maglist = [MAG_PLACEHOLDER, "40X", "100X", "200X", "400X"]
        magnification = st.selectbox("Select the Magnification", maglist)
    with st.container():
        uploaded = st.file_uploader(
            "Choose images to diagnose",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
        )

# --Main--
# Diagnosis runs automatically once at least one image is uploaded and a real
# magnification is selected.
with st.container():
    if uploaded and magnification != MAG_PLACEHOLDER:
        use_model(uploaded, magnification)
        st.session_state["loop"] += 1

# TODO: page navigation (Home / How to use / Contact) via option_menu and the
# Pages module is not wired up yet.