### FRAMEWORKS AND DEPENDENCIES
import copy
import operator
import os
import sys
from collections import OrderedDict
from pathlib import Path
from urllib.parse import quote

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.cm as mpl_color_map
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
import torchxrayvision as xrv
# Other methods available: ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from skimage.io import imread
import pydicom as dicom
import mols2grid
import streamlit.components.v1 as components
from rdkit import Chem
from rdkit.Chem.Descriptors import ExactMolWt
from chembl_webresource_client.new_client import new_client
import streamlit as st

####UTILS.PY
model_names = ['densenet121-res224-mimic_nb', 'densenet121-res224-mimic_ch']

# Pathologies offered in the UI and accepted by the heatmap code.
_PATHOLOGIES = (
    'Atelectasis', 'Consolidation', 'Pneumothorax', 'Edema',
    'Effusion', 'Pneumonia', 'Cardiomegaly',
)


#### FUNCTIONS FOR STREAMLIT
### Cache Drugs (Get Compounds found)
@st.cache(allow_output_mutation=True)
def getdrugs(name, phase):
    """Query ChEMBL for compounds whose indication matches ``name``.

    Args:
        name (str): matched case-insensitively against the EFO indication
            term (``efo_term__icontains``).
        phase (list[str]): max-phase values to keep (e.g. ``['1', '4']``);
            an empty list keeps every phase.

    Returns:
        pd.DataFrame | None: columns pref_name, smiles, Acceptors, Donnors,
        mol_weight and Logp, or None when the response cannot be turned into
        that table (missing columns/keys, unconvertible values, ...).
    """
    drug_indication = new_client.drug_indication
    molecules = new_client.molecule
    obj = drug_indication.filter(efo_term__icontains=name)
    appdrugs = molecules.filter(
        molecule_chembl_id__in=[x['molecule_chembl_id'] for x in obj]
    )

    if phase:
        frames = []
        for ph in phase:
            dftemp = pd.DataFrame.from_dict(appdrugs.filter(max_phase=int(ph)))
            dftemp["phase"] = int(ph)
            frames.append(dftemp)
        df = pd.concat(frames, axis=0)
    else:
        df = pd.DataFrame.from_dict(appdrugs)

    subs = ["pref_name", "smiles", "Acceptors", "Donnors", "mol_weight", "Logp"]
    # Output column -> key inside the ChEMBL ``molecule_properties`` dict.
    property_keys = {
        "Acceptors": "hba",
        "Donnors": "hbd",
        "mol_weight": "mw_freebase",
        "Logp": "cx_logp",
    }
    try:
        df.dropna(subset=["molecule_properties", "molecule_structures"], inplace=True)
        df["smiles"] = df["molecule_structures"].map(operator.itemgetter("canonical_smiles"))
        for column, key in property_keys.items():
            df[column] = df["molecule_properties"].map(operator.itemgetter(key))
        df.dropna(subset=subs, inplace=True)
        df["Acceptors"] = df["Acceptors"].astype(int)
        df["Donnors"] = df["Donnors"].astype(int)
        df["mol_weight"] = df["mol_weight"].astype(float)
        df["Logp"] = df["Logp"].astype(float)
        return df.loc[:, subs]
    except Exception:
        # Best effort: an unexpected/empty response means "no compounds".
        # ``Exception`` (not a bare except) so KeyboardInterrupt and
        # Streamlit's BaseException-based rerun/stop signals still propagate.
        return None


### Title
def header():
    """Render the page title, the description paragraph and the banner image."""
    # NOTE(review): these strings are rendered with unsafe_allow_html=True but
    # contain no markup; the original HTML styling appears to have been lost.
    st.markdown("\n\nChest Anomaly Identifier\n\n", unsafe_allow_html=True)
    ### Description
    st.markdown(
        "\n\nThis is a pocket application that is mainly focused on aiding "
        "medical professionals on their diagnostics and treatments for chest "
        "anomalies based on chest X-Rays. On this application, users can "
        "upload a chest X-Ray image and a deep learning model will output the "
        "probability of 14 different anomalies taking place on that image\n\n",
        unsafe_allow_html=True,
    )
    ### Image
    st.image("doctors.jpg")


### Controllers
def controllers2(model_probs):
    """Draw the sidebar controls and return the values the user selected.

    Args:
        model_probs (dict): output of ``getprobs``; the keys of
            ``model_probs[model_names[0]]`` populate the treatment selectbox.

    Returns:
        tuple: (option_anomaly, threshold, colors, obscureness, option,
        weight_cutoff, logp_cutoff, NumHDonors_cutoff, NumHAcceptors_cutoff,
        max_phase, N), in the order ``main`` unpacks them.
    """
    # Select the anomaly to detect
    st.sidebar.markdown("\n\nAnomaly detection\n\n", unsafe_allow_html=True)
    option_anomaly = st.sidebar.selectbox(
        'Select Anomaly to detect',
        list(_PATHOLOGIES),
        help='Select the anomaly you want to detect',
    )

    # Filtering anomalies (one markdown bullet per control)
    st.sidebar.markdown(
        "\n\nThis controller is used to filter anomaly detection\n\n"
        "- N : Select the number of most likely anomalies you want to detect\n"
        "- Threshold : It measures how strict you are with the threshold\n"
        "- Colors : For color intensity of anomaly detection\n"
        "- Obscureness : For darker or lighter colors\n",
        unsafe_allow_html=True,
    )
    N = st.sidebar.slider(
        label="N", min_value=1, max_value=5, value=3, step=1,
        help="Select the number of most likely anomalies you want to detect",
    )
    threshold = st.sidebar.slider(
        label="Threshold", min_value=0.0, max_value=1.0, value=0.3, step=0.1,
        help="Select the degree of confidence you want to detect. The more is the value the more strict you are in your detection",
    )
    colors = st.sidebar.slider(
        "Intense Colors", min_value=0.0, max_value=1.0, value=0.6, step=0.1,
        help="Select the color intensity you want to display at the time on detecting an anomaly. The higuer the value, the more intense the color",
    )
    obscureness = st.sidebar.slider(
        "Obscureness", min_value=0.0, max_value=1.0, value=0.8, step=0.1,
        help="Select the obscureness you want your colors have. The higuer the value, the more obscure is the color",
    )

    # Select the treatment
    st.sidebar.markdown("\n\nAnomaly Treatment\n\n", unsafe_allow_html=True)
    option = st.sidebar.selectbox(
        'Select the anomaly for treatment',
        list(model_probs[model_names[0]].keys()),
        help='Select the anomaly you want to treat',
    )

    #### Filtering treatments
    st.sidebar.markdown("\n\nCompound's filter\n\n", unsafe_allow_html=True)
    ## Write the compound
    st.sidebar.markdown(
        "\n\nThis controller sidebar is used to filter the compounds by the following features\n\n"
        "- Molecular weight : is the weight of a compound in grame per mol\n"
        "- LogP : it measures how hydrophilic or hydrophobic a compound is\n"
        "- NumDonnors : number of chemical components that are able to deliver electrons to other chemical components\n"
        "- NumAcceptors : number of chemical components that are able to accept electrons to other chemical components\n"
        "- MaxPhase : select the phase in which the compound is stablished\n",
        unsafe_allow_html=True,
    )
    weight_cutoff = st.sidebar.slider(
        label="Molecular weight", min_value=0, max_value=1000, value=500, step=10,
        help="Look for compounds that have less or equal molecular weight than the value selected",
    )
    logp_cutoff = st.sidebar.slider(
        label="LogP", min_value=-10, max_value=10, value=5, step=1,
        help="Look for compounds that have less or equal logp than the value selected",
    )
    NumHDonors_cutoff = st.sidebar.slider(
        label="NumHDonors", min_value=0, max_value=15, value=5, step=1,
        help="Look for compounds that have less or equal donors weight than the value selected",
    )
    NumHAcceptors_cutoff = st.sidebar.slider(
        label="NumHAcceptors", min_value=0, max_value=20, value=10, step=1,
        help="Look for compounds that have less or equal acceptors weight than the value selected",
    )
    max_phase = st.sidebar.multiselect(
        "Select Phase of the compound",
        ['1', '2', '3', '4'],
        help=(
            "\n- Phase 1 : Phase I of the compound in progress\n"
            "- Phase 2 : Phase II of the compound in progress\n"
            "- Phase 3 : Phase III of the compound in progress\n"
            "- Phase 4 : Approved compound\n"
        ),
    )
    return (option_anomaly, threshold, colors, obscureness, option, weight_cutoff,
            logp_cutoff, NumHDonors_cutoff, NumHAcceptors_cutoff, max_phase, N)


### MODEL.PY
def takemodel(models: OrderedDict, cams: OrderedDict, weights="mimic_ch"):
    """Load one pretrained DenseNet plus a GradCAM bound to it.

    Both dicts are filled in place under the key ``weights`` and returned.

    Args:
        models (OrderedDict[str, xrv.models.DenseNet]): receives the model.
        cams (OrderedDict[str, GradCAM]): receives the GradCAM used for heatmaps.
        weights (str): name of the torchxrayvision pretrained weights.

    Returns:
        tuple: ``(models, cams)``.
    """
    models[weights] = xrv.models.DenseNet(weights=weights)
    models[weights].eval()
    # CAM target: the second-to-last module of the DenseNet feature stack.
    target_layer = models[weights].features[-2]
    cams[weights] = GradCAM(models[weights], target_layer, use_cuda=False)
    return models, cams


def _init_models():
    """Load every model in ``model_names`` together with its GradCAM."""
    # NOTE(review): weights are reloaded on every call (i.e. every rerun).
    # Caching is not a drop-in fix: each GradCAM registers hooks on its model
    # and stores activations on itself, so shared instances would accumulate
    # hooks and be shared across Streamlit sessions.
    loaded_models = OrderedDict()
    cams = OrderedDict()
    for model_name in model_names:
        loaded_models, cams = takemodel(loaded_models, cams, weights=model_name)
    return loaded_models, cams


def _xray_transform():
    """Return the center-crop + 224px resize transform the models expect."""
    return transforms.Compose(
        [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224, engine='cv2')]
    )


#### Read the image | Normalize
def normalize(sample, maxval):
    """Linearly map ``[0, maxval]`` to ``[-1024, 1024]`` as float32.

    Args:
        sample (np.ndarray): image array.
        maxval (int | float): upper bound expected for ``sample``.

    Raises:
        ValueError: if ``sample.max()`` exceeds ``maxval``.

    From torchxrayvision.
    """
    if sample.max() > maxval:
        raise ValueError(
            "max image value ({}) higher than expected bound ({}).".format(sample.max(), maxval)
        )
    return (2 * (sample.astype(np.float32) / maxval) - 1.) * 1024


def extensionimages(image_path):
    """Read a JPG/PNG, DICOM or other image into a numpy array.

    The reader is chosen from ``str(image_path)``, case-insensitively: text
    containing "jpg", "jpeg" or "png" goes to PIL, text containing "dcm" to
    pydicom, anything else to ``skimage.io.imread``.

    Args:
        image_path (str | file-like): a path, or an uploaded file object whose
            ``str()`` includes its file name (assumed for Streamlit's
            UploadedFile -- TODO confirm).
    """
    name = str(image_path).lower()
    if any(ext in name for ext in ("jpg", "jpeg", "png")):
        return np.array(Image.open(image_path))
    if "dcm" in name:
        return dicom.dcmread(image_path).pixel_array
    return imread(image_path)


def read_image(img, tr=None, visualize=True):
    """Convert a raw image array into the (transformed) model input.

    Keeps only the first channel of multi-channel images, min-max scales to
    [0, 255], maps that to roughly [-1024, 1024] with ``normalize``, adds a
    leading channel axis and applies ``tr``.

    Args:
        img (np.ndarray): image as returned by ``extensionimages``.
        tr (callable): transform applied to the ``(1, H, W)`` array; required.
        visualize (bool): unused; kept for backward compatibility.

    Returns:
        np.ndarray: ``tr`` applied to the normalized image.

    Raises:
        ValueError: if ``tr`` is None.

    From torchxrayvision.
    """
    if tr is None:
        raise ValueError("You should pass a transformer to downsample the images")
    # Float first: integer max-min can overflow for signed (int16) DICOM data.
    img = np.asarray(img).astype(np.float64)
    ### If the image has channels keep only the first one
    if img.ndim >= 3:
        img = img[:, :, 0]
    lo, hi = img.min(), img.max()
    if hi > lo:
        img = (img - lo) / (hi - lo) * 255
    else:
        # Constant image: avoid 0/0 -> NaN propagating into the models.
        img = np.zeros(img.shape, dtype=np.float64)
    ### Normalize to values -1024 1024
    img = normalize(img, 255)
    # Add color channel
    return tr(img[None, :, :])


def _get_colormap(name):
    """Look up a matplotlib colormap by name on old and new matplotlib."""
    registry = getattr(mpl, "colormaps", None)  # available since matplotlib 3.5
    if registry is not None:
        return registry[name]
    # matplotlib.cm.get_cmap was removed in matplotlib 3.9.
    return mpl_color_map.get_cmap(name)


#### Apply colormap on image
def apply_colormap_on_image(org_im, activation, colormap_name, threshold=0.3, alpha=0.6):
    """Overlay a colormapped activation map on an image.

    Args:
        org_im (PIL.Image): original image; must match ``activation`` in size.
        activation (np.ndarray): 2-D activation map; values are fed straight
            to the colormap, so presumably in [0, 1] -- TODO confirm.
        colormap_name (str): name of a matplotlib colormap.
        threshold (float): pixels with activation <= threshold are made fully
            transparent in the overlay.
        alpha (float): opacity of the overlay where activation > threshold.

    Returns:
        tuple: ``(no_trans_heatmap, heatmap_on_image)`` -- the opaque RGBA
        colormap image and the RGBA composite of ``org_im`` plus overlay.

    Original source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
    Added thresholding to activations.
    """
    activation = np.asarray(activation)
    color_map = _get_colormap(colormap_name)
    # RGBA floats with shape activation.shape + (4,)
    no_trans_heatmap = color_map(activation)

    heatmap = no_trans_heatmap.copy()
    heatmap[:, :, 3] = alpha
    # Fully transparent wherever the activation does not pass the threshold.
    heatmap[activation <= threshold, 3] = 0

    heatmap = Image.fromarray((heatmap * 255).astype(np.uint8))
    no_trans_heatmap = Image.fromarray((no_trans_heatmap * 255).astype(np.uint8))

    heatmap_on_image = Image.new("RGBA", org_im.size)
    heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA'))
    heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap)
    return no_trans_heatmap, heatmap_on_image


def _tensor_to_pil(input_tensor):
    """Undo the [-1024, 1024] scaling of a (1, 1, H, W) tensor into a grayscale PIL image."""
    img = input_tensor.numpy()[0, 0, :, :]
    img = np.clip((img / 1024.0 / 2.0) + 0.5, 0, 1)
    return Image.fromarray(np.uint8(img * 255), 'L')


def heatmap_core(image: np.ndarray, pathologies: list, target: str, model_cmaps: list,
                 threshold=0.3, alpha=0.8, obscureness=0.8, fontsize=14) -> plt:
    """Draw the GradCAM heatmaps of every model for ``target`` on one figure.

    Each model's overlay is composited on top of the previous one, so the
    final image shows all models at once, with one legend entry per model.

    Args:
        image (np.ndarray): preprocessed image as returned by ``read_image``.
        pathologies (list): pathologies accepted for ``target``.
        target (str): pathology to explain.
        model_cmaps (list): one matplotlib colormap per model, in model order.
        threshold (float): activations <= threshold are not drawn.
        alpha (float): opacity of the drawn activations.
        obscureness (float): currently unused; kept for interface compatibility.
        fontsize (float): legend font size.

    Returns:
        module: ``matplotlib.pyplot``, with the new figure as current figure.

    Raises:
        ValueError: if ``target`` is not in ``pathologies`` or unknown to a model.

    Original source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
    Modifications by : ### TeamMIMICIV
    Added thresholding to activations.
    """
    # Validate before the (slow) model loading; not an assert so -O keeps it.
    if target not in pathologies:
        raise ValueError("Pathology input not in pathology maps")

    loaded_models, cams = _init_models()
    input_tensor = torch.from_numpy(image).unsqueeze(0)
    img = _tensor_to_pil(input_tensor)

    plt.figure(figsize=(10, 8))
    for i, (model_name, model) in enumerate(loaded_models.items()):
        target_category = model.pathologies.index(target)
        # Only one image in the batch.
        grayscale_cam = cams[model_name](input_tensor=input_tensor,
                                         target_category=target_category)[0, :]
        _, img = apply_colormap_on_image(img, grayscale_cam, model_cmaps[i].name,
                                         threshold=threshold, alpha=alpha)
        # Invisible line used only to create this model's legend entry.
        plt.plot(0, 0, '-', lw=6, color=model_cmaps[i](0.7), label=model_name)

    plt.legend(fontsize=fontsize)
    plt.imshow(img, cmap='bone')
    plt.axis('off')
    return plt


def heatmap(img, target, threshold=0.3, alpha=0.8, obscureness=0.8, fontsize=14):
    """Preprocess a raw image and draw the per-model heatmaps for ``target``.

    Args:
        img (np.ndarray): raw image as returned by ``extensionimages``.
        target (str): pathology to explain; must be one of the UI pathologies.
        threshold, alpha, obscureness, fontsize: forwarded to ``heatmap_core``.

    Returns:
        module: ``matplotlib.pyplot`` holding the heatmap figure.

    Original source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
    Modifications by : ### TeamMIMICIV
    """
    model_cmaps = [mpl_color_map.Purples, mpl_color_map.Greens_r]
    image = read_image(img, tr=_xray_transform())
    return heatmap_core(image, list(_PATHOLOGIES), target, model_cmaps,
                        threshold=threshold, alpha=alpha,
                        obscureness=obscureness, fontsize=fontsize)


def probtemp(image: np.ndarray) -> dict:
    """Return every model's output for a preprocessed image.

    Args:
        image (np.ndarray): image already processed by ``read_image``.

    Returns:
        dict: ``{model_name: {pathology: score}}``; pathology names of two
        characters or fewer (e.g. empty placeholders) are skipped.
    """
    loaded_models, _ = _init_models()
    input_tensor = torch.from_numpy(image).unsqueeze(0)
    model_dics = {}
    with torch.no_grad():
        for model_name, model in loaded_models.items():
            scores = model(input_tensor).cpu().detach().numpy()[0]
            model_dics[model_name] = {
                key: value for key, value in zip(model.pathologies, scores) if len(key) > 2
            }
    return model_dics


def getprobs(img):
    """Preprocess a raw image and return every model's pathology scores.

    Args:
        img (np.ndarray): raw image as returned by ``extensionimages``.

    Returns:
        dict: ``{model_name: {pathology: score}}`` (see ``probtemp``).
    """
    image = read_image(img, tr=_xray_transform())
    return probtemp(image)


#### MORE FUNCTIONS.PY
### Get the probability of models
def sortedmodels(probs, model_name):
    """Return ``probs[model_name]`` as a dict sorted by descending score.

    Args:
        probs (dict): ``{model_name: {pathology: score}}``.
        model_name (str): which model's scores to sort.
    """
    promodels = probs[model_name]
    return dict(sorted(promodels.items(), key=operator.itemgetter(1), reverse=True))


def disprobs(model_probs, model_name, N):
    """Show the N highest-scoring pathologies of one model in an expander.

    Each metric shows the rank as value and the raw score as delta.

    Args:
        model_probs (dict): ``{model_name: {pathology: score}}``.
        model_name (str): model to display.
        N (int): number of pathologies to show.
    """
    exp1 = st.expander(f"Probabilities for {model_name}")
    pr = sortedmodels(model_probs, model_name)
    for cnt, (key, value) in enumerate(pr.items()):
        if cnt == N:
            break
        exp1.metric(label=key, value=str(cnt + 1), delta=str(value))


def getfile(uploaded_file=None):
    """Read the uploaded file, or the bundled ``example.dcm`` when none is given."""
    if uploaded_file is not None:
        return extensionimages(uploaded_file)
    return extensionimages("example.dcm")


### Error in case we do not find compounds
def error(option):
    """Tell the user no compounds were found and link to a ChEMBL search for ``option``."""
    url = "https://www.ebi.ac.uk/chembl/g/#search_results/all/query={}".format(quote(str(option)))
    par = "<a href='{}' target='_blank'>ChEMBL</a>".format(url)
    st.markdown(
        "\n\nWe have not found compounds for this illness; for more information visit this link: {}\n\n".format(par),
        unsafe_allow_html=True,
    )


def main():
    """Streamlit entry point: upload, predict, explain and suggest compounds."""
    sys.path.insert(0, "..")
    ### Title (set_page_config must be the first Streamlit call)
    st.set_page_config(layout="wide")
    header()

    ### Uploader
    uploaded_file = st.file_uploader(
        "Choose an X-Ray image to detect anomalies of the chest (the file must be a dicom extension or jpg)",
    )
    #### Get the image
    imgdef = getfile(uploaded_file)

    __, col4, _, col5, _, col6, __ = st.columns((0.1, 1, 0.2, 2.5, 0.2, 1, 0.1))
    col5.markdown("\n\nInput Image\n\n", unsafe_allow_html=True)
    with col5:
        ### Plot the input image
        fig, ax = plt.subplots()
        ax.imshow(imgdef, cmap="gray")
        st.pyplot(fig=fig)
        plt.close(fig)  # figures otherwise accumulate across reruns

    # Printing the possibility of having anomalies
    __, col1, _, col3, _, col2, __ = st.columns((0.1, 1, 0.2, 2.5, 0.2, 1, 0.1))
    col3.markdown("\n\nAnomaly Detection\n\n", unsafe_allow_html=True)
    model_probs = getprobs(imgdef)
    (option_anomaly, threshold, colors, obscureness, option, weight_cutoff,
     logp_cutoff, NumHDonors_cutoff, NumHAcceptors_cutoff, max_phase, N) = controllers2(model_probs)

    ### MODEL 1
    with col1:
        disprobs(model_probs, model_names[0], N)
    ### MODEL_2
    with col2:
        disprobs(model_probs, model_names[1], N)
    ### ANOMALY HEATMAP
    with col3:
        plot = heatmap(imgdef, option_anomaly, threshold, colors, obscureness, 14)
        st.pyplot(plot)
        plot.close()  # close the heatmap figure (the current figure)

    df = getdrugs(option, max_phase)
    st.markdown("\n\nCompounds for {}\n\n".format(option), unsafe_allow_html=True)
    __, col10, col11, _, _, col12, __ = st.columns((0.1, 0.8, 2.5, 0.2, 0.2, 1, 0.1))

    #### We do not find compounds for the anomaly
    if df is None:
        error(option)
        return

    ### TREATMENT FILTERING: every cutoff is inclusive, as the slider help says.
    df_result = df[
        (df["mol_weight"] <= weight_cutoff)
        & (df["Logp"] <= logp_cutoff)
        & (df["Donnors"] <= NumHDonors_cutoff)
        & (df["Acceptors"] <= NumHAcceptors_cutoff)
    ]
    if df_result.empty:
        error(option)
        return

    raw_html = mols2grid.display(
        df_result,
        mapping={"smiles": "SMILES", "pref_name": "Name", "Acceptors": "Acceptors",
                 "Donnors": "Donnors", "Logp": "Logp", "mol_weight": "mol_weight"},
        subset=["img", "Name"],
        tooltip=["Name", "Acceptors", "Donnors", "Logp", "mol_weight"],
        tooltip_placement="top",
        tooltip_trigger="click hover",
    )._repr_html_()
    with col11:
        components.html(raw_html, width=900, height=900, scrolling=True)


if __name__ == "__main__":
    main()