Rules99 committed on
Commit 7519073
1 Parent(s): e67c60d

YouRadiologist Update

Files changed (3)
  1. .gitignore +134 -0
  2. app.py +503 -225
  3. requirements.txt +13 -97
.gitignore ADDED
@@ -0,0 +1,134 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+
+ ### other files
+ csvs/
+ model/
app.py CHANGED
@@ -1,51 +1,38 @@
- import streamlit as st
- from numpy import load
- from numpy import expand_dims
- from matplotlib import pyplot
- from PIL import Image, ImageDraw, ImageFont
- import numpy as np
  import os
- import os,sys
- sys.path.insert(0,"..")
- from glob import glob
- import torch
- import torchvision
  import sys
- import torch.nn.functional as F
- import torchxrayvision as xrv
- import pydicom as dicom
- import PIL # optional
  import pandas as pd
  import matplotlib.pyplot as plt
- import os
- import cv2
- import skimage
- from skimage.transform import rescale, resize, downscale_local_mean
  import operator
  import mols2grid
  import streamlit.components.v1 as components
  from rdkit import Chem
  from rdkit.Chem.Descriptors import ExactMolWt
  from chembl_webresource_client.new_client import new_client

- ### Title
- st.markdown("<h1 style='text-align: center;'>Chest Anomaly Identifier</h1>",unsafe_allow_html=True)
- ### Description
- st.markdown("""<p style='text-align: center;'>The goal of this application is mainly to help doctors to interpret
- Chest X-Ray Images, being able to find medical compounds in a quick way to deal with Chest's anomalies found</p>""",unsafe_allow_html=True)
-
- ### Image
- st.image("doctors.jpg")
-
- ### Uploder
- # st.markdown("""<p style='text-align: center;'>The goal of this application is mainly to help doctors to interpret
- # Chest X-Ray Images, being able to find medical compounds in a quick way to deal with Chest's anomalies found</p>""",unsafe_allow_html=True)
- uploaded_file = st.file_uploader("Choose an X-Ray image to detect anomalies of the chest (the file must be a dicom extension or jpg)")
-
-
-

- #### Get Compounds found
  @st.cache(allow_output_mutation=True)
  def getdrugs(name,phase):
  drug_indication = new_client.drug_indication
@@ -87,212 +74,503 @@ def getdrugs(name,phase):
  return df.loc[:,subs]
  except:
  return None

- ### Read Chest X Ray Image
- def read_image(imgpath):

- if (str(imgpath).find("jpg")!=-1) or (str(imgpath).find("png")!=-1):

  # sample = Image.open("JPG_test/0c4eb1e1-b801903c-bcebe8a4-3da9cd3c-3b94a27c.jpg")
- sample = Image.open(imgpath)
  return np.array(sample)
- if str(imgpath).find("dcm")!=-1:
- img = dicom.dcmread(imgpath).pixel_array
  return img

- ### Generate torchxrayvision model to find output probabilities
- def generatemodel(xrvmodel,wts):
- return xrvmodel(weights=wts)
- ### Transform the image to ouput some illness
- def transform2(img):
- input_tensor = torch.from_numpy(img).unsqueeze(0)
- img = input_tensor.numpy()[0, 0, :]
  img = (img / 1024.0 / 2.0) + 0.5
  img = np.clip(img, 0, 1)
  img = Image.fromarray(np.uint8(img * 255) , 'L')
- return img
- ### Transform the image to test an output image
- def transform(img):

- img = ((img-img.min())/(img.max()-img.min())*255)
-
-
- # img = (img / 1024.0 / 2.0) + 0.5
- # img = np.clip(img, 0, 1)
- # img = Image.fromarray(np.uint8(img * 255) , 'L')
- # print(img.shape)
- # img = skimage.io.imread("JPG_test/0c4eb1e1-b801903c-bcebe8a4-3da9cd3c-3b94a27c.jpg")
- # print(img.max())
- img = xrv.datasets.normalize(np.array(img), 255)
-
- # Check that images are 2D arrays
- if len(img.shape) > 2:
- img = img[:, :, 0]
- if len(img.shape) < 2:
- print("error, dimension lower than 2 for image")
-
- # Add color channel
- img = img[None, :, :]
-
- transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(),
- xrv.datasets.XRayResizer(224,engine="cv2")])
-
- img = transform(img)
- return img
- ### Returns the output probabilities of having certain illnesses anomalies
- def testimage(model,img):
- # with torch.no_grad():
- model.eval()
- out = model(torch.from_numpy(img).unsqueeze(0)).cpu()
- # out = model(img).cpu()
- # out = torch.sigmoid(out)
-
- return {key:value for (key,value) in zip(model.pathologies, out.detach().numpy()[0]) if len(key)>2}
-
- ### Resize the model
- def outputprob2(img,pr_model,visimage=True):
- ### Read an image
- img = resize(img, (img.shape[0] // 2, img.shape[1] // 2),
- anti_aliasing=True)

- ### Preprocessmodel
- img_t = transform(img)
- ### Test an image
- return testimage(pr_model,img_t)
-
- ### Pipeline since we read an image until the ouput it is generated
- def outputprob(imgpath,pr_model,visimage=True):
- ### Read an image
- img = read_image(imgpath)
- if visimage:
- plt.imshow(img,cmap="gray")
- plt.show()
- ### Preprocessmodel
- img_t = transform(img)
- ### Test an image
- return testimage(pr_model,img_t)


  ### Error in case we do not find compounds
  def error(option):
  option = str(option).replace(" ","%20")
- st.markdown(f"""
- We have not found compounds for this illness; for more information visit this link:
- [ChEMBL](https://www.ebi.ac.uk/chembl/g/#search_results/all/query={option})
- """, unsafe_allow_html=True)
-
- ### If you insert an image
- if uploaded_file is not None:
-
- #### Read an image
-
-
- imgdef = read_image(uploaded_file)
- else:
- imgdef = read_image("example.dcm")
- ## Controller header
-
- st.sidebar.markdown("<h1 style='text-align: center;'>Compound's filter</h1>",unsafe_allow_html=True)
- ## Write the compound
- st.sidebar.markdown('''
- <h4 style='text-align: center;'>This controller sidebar is used to filter the compounds by the following features</h4>
-
- - Molecular weight : is the weight of a compound in grame per mol
- - LogP : it measures how hydrophilic or hydrophobic a compound is
- - NumDonnors : number of chemical components that are able to deliver electrons to other chemical components
- - NumAcceptors : number of chemical components that are able to accept electrons to other chemical components
- ''',unsafe_allow_html=True)
- weight_cutoff = st.sidebar.slider(
- label="Molecular weight",
- min_value=0,
- max_value=1000,
- value=500,
- step=10,
- help="Look for compounds that have less or equal molecular weight than the value selected"
- )
- logp_cutoff = st.sidebar.slider(
- label="LogP",
- min_value=-10,
- max_value=10,
- value=5,
- step=1,
- help="Look for compounds that have less or equal logp than the value selected"
- )
- NumHDonors_cutoff = st.sidebar.slider(
- label="NumHDonors",
- min_value=0,
- max_value=15,
- value=5,
- step=1,
- help="Look for compounds that have less or equal donors weight than the value selected"
- )
- NumHAcceptors_cutoff = st.sidebar.slider(
- label="NumHAcceptors",
- min_value=0,
- max_value=20,
- value=10,
- step=1,
- help="Look for compounds that have less or equal acceptors weight than the value selected"
- )
- max_phase = st.sidebar.multiselect("Phase of the compound",
- ['1','2', '3', '4'],
- help="""
- - Phase 1 : Phase I of the compound in progress
- - Phase 2 : Phase II of the compound in progress
- - Phase 3 : Phase III of the compound in progress
- - Phase 4 : Approved compound
- """
- )
-
-
- ### Plot the input image
- fig, ax = plt.subplots()
- ax.imshow(imgdef,cmap="gray")
- st.pyplot(fig=fig)
- # Printing the possibility of having anomalies
- st.markdown("<h3 style='text-align: center;'>Possibility of anomalies</h3>",unsafe_allow_html=True)
- model = generatemodel(xrv.models.DenseNet,"densenet121-res224-mimic_ch") ### MIMIC MODEL+
- model.eval()
- pr = outputprob2(imgdef,model)
-
- # Sort results by the descending probability order
- pr = dict( sorted(pr.items(), key=operator.itemgetter(1),reverse=True))
- # Select the treatment
- option = st.sidebar.selectbox('Anomaly',list(pr.keys()),help='Select the illness or anomaly you want to treat')
- col1,col2,col3 = st.columns((1,1,1))
- cnt = 1
- for (key,value) in pr.items():
- if cnt%3==1:
- col1.metric(label=key, value=str(cnt), delta=str(value))
- if cnt%3==2:
- col2.metric(label=key, value=str(cnt), delta=str(value))
- if cnt%3==0:
- col3.metric(label=key, value=str(cnt), delta=str(value))
- cnt+=1
- # temp = st.expander("Compunds to take care of {}".format(key))
- #### Get the compounds for the anomaly selected
- df = getdrugs(option,max_phase)
- st.markdown("<h3 style='text-align: center;'>Compounds for {}</h3>".format(option),unsafe_allow_html=True)
- ### If exists the compounds
- if df is not None:
-
- #### Filter dataframe by controllers
- df_result = df[df["mol_weight"] < weight_cutoff]
- df_result2 = df_result[df_result["Logp"] < logp_cutoff]
- df_result3 = df_result2[df_result2["Donnors"] < NumHDonors_cutoff]
- df_result4 = df_result3[df_result3["Acceptors"] < NumHAcceptors_cutoff]
-
-
-
- if len(df_result4)==0:

  error(option)
- else:
- raw_html = mols2grid.display(df_result, mapping={"smiles": "SMILES","pref_name":"Name","Acceptors":"Acceptors","Donnors":"Donnors","Logp":"Logp","mol_weight":"mol_weight"},
- subset=["img","Name"],tooltip=["Name","Acceptors","Donnors","Logp","mol_weight"],tooltip_placement="top",tooltip_trigger="click hover")._repr_html_()
-
- components.html(raw_html, width=900, height=900, scrolling=True)
- #### We do not find compounds for the anomaly
- else:
- error(option)

+ ### FRAMEWORKS AND DEPENDENCIES
+ import copy
  import os
  import sys
+ from collections import OrderedDict
+ from pathlib import Path
+ import numpy as np
  import pandas as pd
  import matplotlib.pyplot as plt
+ import matplotlib.cm as mpl_color_map
+ from PIL import Image, ImageFilter
+ from collections import OrderedDict
+ import matplotlib as mpl
+ import torch
+ import torch.nn as nn
+ from torchvision import datasets, models, transforms
+ import torchxrayvision as xrv
+ from pytorch_grad_cam import GradCAM
+ # Other methods available: ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
+ from pytorch_grad_cam.utils.image import show_cam_on_image
+ from skimage.io import imread
+ import pydicom as dicom
  import operator
  import mols2grid
  import streamlit.components.v1 as components
  from rdkit import Chem
  from rdkit.Chem.Descriptors import ExactMolWt
  from chembl_webresource_client.new_client import new_client
+ import streamlit as st

+ ####UTILS.PY
+ model_names = ['densenet121-res224-mimic_nb', 'densenet121-res224-mimic_ch']

+ #### FUNCTIONS FOR STREAMLIT
+ ### Cache Drugs (Get Compounds found)
  @st.cache(allow_output_mutation=True)
  def getdrugs(name,phase):
  drug_indication = new_client.drug_indication
  return df.loc[:,subs]
  except:
  return None
+ ### Title
+ def header():
+
+ st.markdown("<h1 style='text-align: center;'>Chest Anomaly Identifier</h1>",unsafe_allow_html=True)
+ ### Description
+ st.markdown("""<p style='text-align: center;'>This is a pocket application that is mainly focused on aiding medical
+ professionals in their diagnostics and treatments for chest anomalies based on chest X-Rays. In this application, users
+ can upload a chest X-Ray image and a deep learning model will output the probability of 14 different anomalies being
+ present in that image</p>""",unsafe_allow_html=True)
+
+ ### Image
+ st.image("doctors.jpg")
+ ### Controllers
+ def controllers2(model_probs):
+
+
+
+
+ # Select the anomaly to detect
+ st.sidebar.markdown("<h1 style='text-align: center;'>Anomaly detection</h1>",unsafe_allow_html=True)
+ option_anomaly = st.sidebar.selectbox('Select Anomaly to detect',['Atelectasis', 'Consolidation', 'Pneumothorax','Edema', 'Effusion', 'Pneumonia', 'Cardiomegaly'],help='Select the anomaly you want to detect')
+ # Filtering anomalies
+ st.sidebar.markdown('''
+ <h4 style='text-align: center;'>This controller is used to filter anomaly detection </h4>
+
+ - N : Select the number of most likely anomalies you want to detect
+ - Threshold : controls how strict the detection has to be before a region is highlighted
+ - Colors : For color intensity of anomaly detection
+ - Obscureness : For darker or lighter colors
+
+
+ ''',unsafe_allow_html=True)
+
+ N = st.sidebar.slider(label="N",min_value=1,max_value=5,value=3,step=1,help="Select the number of most likely anomalies you want to detect")
+ threshold = st.sidebar.slider(label="Threshold",min_value=0.0,max_value=1.0,value=0.3,step=0.1,help="Select the degree of confidence required for a detection. The higher the value, the stricter the detection")
+ colors = st.sidebar.slider("Intense Colors",min_value=0.0,max_value=1.0,value=0.6,step=0.1,help="Select the color intensity used to display a detected anomaly. The higher the value, the more intense the color")
+ obscureness = st.sidebar.slider("Obscureness",min_value=0.0,max_value=1.0,value=0.8,step=0.1,help="Select how dark you want the colors to be. The higher the value, the darker the color")
+
+
+ # Select the treatment
+
+ st.sidebar.markdown("<h1 style='text-align: center;'>Anomaly Treatment</h1>",unsafe_allow_html=True)
+ option = st.sidebar.selectbox('Select the anomaly for treatment',list(model_probs[model_names[0]].keys()),help='Select the anomaly you want to treat')
+
+
+
+ #### Filtering treatments
+ st.sidebar.markdown("<h1 style='text-align: center;'>Compound's filter</h1>",unsafe_allow_html=True)
+ ## Write the compound
+ st.sidebar.markdown('''
+ <h4 style='text-align: center;'>This controller sidebar is used to filter the compounds by the following features</h4>
+
+ - Molecular weight : the weight of a compound in grams per mole
+ - LogP : measures how hydrophilic or hydrophobic a compound is
+ - NumDonnors : number of hydrogen bond donors in the compound
+ - NumAcceptors : number of hydrogen bond acceptors in the compound
+ - MaxPhase : select the clinical phase the compound has reached
+ ''',unsafe_allow_html=True)
+ weight_cutoff = st.sidebar.slider(
+ label="Molecular weight",
+ min_value=0,
+ max_value=1000,
+ value=500,
+ step=10,
+ help="Look for compounds whose molecular weight is below the selected value"
+ )
+ logp_cutoff = st.sidebar.slider(
+ label="LogP",
+ min_value=-10,
+ max_value=10,
+ value=5,
+ step=1,
+ help="Look for compounds whose LogP is below the selected value"
+ )
+ NumHDonors_cutoff = st.sidebar.slider(
+ label="NumHDonors",
+ min_value=0,
+ max_value=15,
+ value=5,
+ step=1,
+ help="Look for compounds whose number of H donors is below the selected value"
+ )
+ NumHAcceptors_cutoff = st.sidebar.slider(
+ label="NumHAcceptors",
+ min_value=0,
+ max_value=20,
+ value=10,
+ step=1,
+ help="Look for compounds whose number of H acceptors is below the selected value"
+ )
+ max_phase = st.sidebar.multiselect("Select Phase of the compound",
+ ['1','2', '3', '4'],
+ help="""
+ - Phase 1 : Phase I of the compound in progress
+ - Phase 2 : Phase II of the compound in progress
+ - Phase 3 : Phase III of the compound in progress
+ - Phase 4 : Approved compound """
+ )
+
+ return option_anomaly,threshold,colors,obscureness,option,weight_cutoff,logp_cutoff,NumHDonors_cutoff,NumHAcceptors_cutoff,max_phase,N
+
+
+
+ ### MODEL.PY
+
+ def takemodel(models:OrderedDict,cams:OrderedDict,weights="mimic_ch"):
+ """
+ Define models and CAMs for each model; tools used for the heatmap
+ Args:
+ models (OrderedDict[xrv.models.DenseNet]): the CNN models
+ cams (OrderedDict[GradCAM]): Grad-CAM objects used to build the heatmap
+ weights (str): Name of the pretrained model weights
+ """
+ models[weights] = xrv.models.DenseNet(weights=weights)
+ models[weights].eval()
+ target_layer = models[weights].features[-2]
+ cams[weights] = GradCAM(models[weights], target_layer, use_cuda=False)
+ return models,cams
+ #### Read the image | Normalize
+ def normalize(sample, maxval):
+ """
+ Scales images to be roughly [-1024, 1024].
+ Args:
+ sample (np.array): image array (from a dicom, jpg, or png file)
+ maxval (int): maximum value of the input image
+
+ From torchxrayvision
+ """
+
+ if sample.max() > maxval:
+ raise Exception("max image value ({}) higher than expected bound ({}).".format(sample.max(), maxval))
+
+ sample = (2 * (sample.astype(np.float32) / maxval) - 1.) * 1024
+ #sample = sample / np.std(sample)
+ return sample

+ def extensionimages(image_path):
+ """
+ Read a jpg, dicom, or png image; if the extension is not recognized, fall back to skimage.io.imread(image_path)
+ Args:
+ image_path (str): path of the image

+ """
+
+ if (str(image_path).find("jpg")!=-1) or (str(image_path).find("png")!=-1):

  # sample = Image.open("JPG_test/0c4eb1e1-b801903c-bcebe8a4-3da9cd3c-3b94a27c.jpg")
+ sample = Image.open(image_path)
  return np.array(sample)
+ if str(image_path).find("dcm")!=-1:
+ img = dicom.dcmread(image_path).pixel_array
+
  return img
+ else:
+ return imread(image_path)
+
+
+ def read_image(img, tr=None,visualize=True):
+ """
+ Scales images to be roughly [-1024, 1024].
+ Args:
+ img (np.array): image array (as returned by extensionimages)
+ tr (callable): transform applied after normalization (center crop + resize)
+ From torchxrayvision
+ """
+ # img = extensionimages(image_path)
+ ### If the image has 3 dims, keep just the first channel
+
+
+ try:
+ img = img[:, :, 0]
+ ### Otherwise the image is already 2-D
+ except IndexError:
+ pass
+ # Another option would be equalizing the image
+ # img = cv2.equalizeHist(img.astype(np.uint8))
+ img = ((img-img.min())/(img.max()-img.min())*255)
+ ### Normalize to values in [-1024, 1024]
+ img = normalize(img, 255)
+ # print(img.min(),img.max())
+ # Add color channel
+ img = img[None, :, :]
+ if tr is not None:
+ img = tr(img)
+ else:
+ raise Exception("You should pass a transformer to downsample the images")
+ return img
+
+ #### Apply colormap on image
+ def apply_colormap_on_image(org_im, activation, colormap_name, threshold=0.3,alpha=0.6):
+ """
+ Apply heatmap on image
+ Args:
+ org_img (PIL img): Original image (224x224)
+ activation_map (numpy arr): Activation map (grayscale) 0-255 (224x224)
+ colormap_name (str): Name of the colormap (colormap_name)
+ threshold (float): activation value a pixel must surpass for the heatmap to be overlaid
+ alpha (float): opacity of the heatmap overlay
+ Original source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
+
+ Added thresholding to activations.
+ """
+ ### Grayscale_cam
+ grayscale_cam = copy.deepcopy(activation)
+ # Get the colormap (just the color type)
+ color_map = mpl_color_map.get_cmap(colormap_name)
+ # Map the activation values through the colormap
+
+ no_trans_heatmap = color_map(activation)
+ ### no_trans_heatmap output (224x224x4 channels) (RGBA channels)
+ ### R --> channel 0, G --> channel 1, B --> channel 2, alpha --> channel 3

+ # Change the alpha channel in the colormap to make sure the original image stays visible
+ alpha_channel = 3
+ heatmap = copy.copy(no_trans_heatmap)
+ heatmap[:, :, alpha_channel] = alpha
+
+ # set to fully transparent if there is a very low activation (if the activation map is lower than the threshold)
+ idx = (grayscale_cam <= threshold)
+ # convert to a 3D index with the shape of the image
+ # Input shape 224x224 --- Output Shape 224x224x1
+ ignore_idx = np.expand_dims(np.zeros(grayscale_cam.shape, dtype=bool), 2)
+
+ ### idx covers the fourth dimension of the heatmap: concatenate 224x224x3 with 224x224x1 ---> 224x224x4
+ idx = np.concatenate([ignore_idx]*3 + [np.expand_dims(idx, 2)], axis=2)
+
+
+ heatmap[idx] = 0
+ ### Inputs 224x224x4
+ ### Scale to 0-255 integers and convert to a PIL image
+ heatmap = Image.fromarray((heatmap*255).astype(np.uint8))
+ ### Colormapped activation, scaled to 255, as a PIL image
+ no_trans_heatmap = Image.fromarray((no_trans_heatmap*255).astype(np.uint8))
+
+ # Apply heatmap on image
+ ### Create an RGBA image
+ heatmap_on_image = Image.new("RGBA", org_im.size)
+ ### org_im converted to RGBA and composited onto heatmap_on_image
+ heatmap_on_image = Image.alpha_composite(heatmap_on_image, org_im.convert('RGBA'))
+ ### heatmap_on_image composited with the heatmap
+ heatmap_on_image = Image.alpha_composite(heatmap_on_image, heatmap)
+ return no_trans_heatmap, heatmap_on_image
+
+
+
+ def heatmap_core(image:np.array,pathologies:list,target:str,model_cmaps:list,threshold = 0.3, alpha = 0.8,obscureness = 0.8,fontsize=14)->plt:
+ """
+ Returns the heatmap of the image
+ Args:
+ image (np.array): Numpy Array Image (224x224)
+ target (str): Pathology to select
+ model_cmaps (list): colormaps used for the heatmap
+ pathologies(list): List of pathologies
+ threshold (float): Threshold controlling how strict the highlighted region is
+ alpha (float): the higher this value, the more intense the colormaps
+ obscureness (float) : the higher this value, the darker the colormaps
+ fontsize (float): adjusts the font size of the plot
+ Original source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
+ Modifications by : ### TeamMIMICIV
+
+ Added thresholding to activations.
+ """
+
+ #### Initializing models
+ models = OrderedDict()
+ cams = OrderedDict()
+ for model_name in ['densenet121-res224-mimic_nb', 'densenet121-res224-mimic_ch']:
+ #### Adding the models and cams to the OrderedDict structure
+ models,cams = takemodel(models,cams,weights=model_name)
+ ### Get an image
+ input_tensor = torch.from_numpy(image).unsqueeze(0)
+
+ img = input_tensor.numpy()[0, 0, :, :]
  img = (img / 1024.0 / 2.0) + 0.5
  img = np.clip(img, 0, 1)
  img = Image.fromarray(np.uint8(img * 255) , 'L')

+ # using the variable axs for multiple Axes
+ plt.figure(figsize=(10, 8))
+
+ i = 0
+ for model_name, model in models.items():
+ # get the model predictions
+ with torch.no_grad():
+ out = model(input_tensor).cpu()
+
+ # reshape the dataset labels to match our model
+ # xrv.datasets.relabel_dataset(model.pathologies, d_pc)
+
+ # finds the index of the target based on the model pathologies
+ assert target in pathologies,"Pathology input not in pathology maps"
+ target_category = model.pathologies.index(target)
+ grayscale_cam = cams[model_name](input_tensor=input_tensor, target_category=target_category)
+ # In this example grayscale_cam has only one image in the batch:
+ grayscale_cam = grayscale_cam[0, :]
+
+ _, img = apply_colormap_on_image(img, grayscale_cam, model_cmaps[i].name, threshold=threshold,alpha=alpha)
+
+ # add a dummy line so the model colour appears in the legend
+ plt.plot(0, 0, '-', lw=6, color=model_cmaps[i](0.7), label=model_name)
+
+ # what did we predict?
+ prob = np.round(out[0].detach().numpy()[target_category], 4)
+
+ i += 1
+
+ plt.legend(fontsize=fontsize)
+ plt.imshow(img, cmap='bone')
+ plt.axis('off')
+ # plt.show()
+ return plt
+
+
+ def heatmap(img,target,threshold = 0.3, alpha = 0.8,obscureness = 0.8,fontsize=14):
+ """
+ Returns the heatmap of the image
+ Args:
+ img (np.array): image array
+ target (str): Pathology to select

+ threshold (float): Threshold controlling how strict the highlighted region is
+ alpha (float): the higher this value, the more intense the colormaps
+ obscureness (float) : the higher this value, the darker the colormaps
+ fontsize (float): adjusts the font size of the plot
+ Original source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
+ Modifications by : ### TeamMIMICIV
+ Added thresholding to activations.
+ """
+ pathologies = ['Atelectasis', 'Consolidation', 'Pneumothorax','Edema', 'Effusion', 'Pneumonia', 'Cardiomegaly']
+ model_cmaps = [mpl_color_map.Purples, mpl_color_map.Greens_r]
+ tr = transforms.Compose(
+ [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224, engine='cv2')]
+ )
+ image = read_image(img,tr=tr)
+ return heatmap_core(image,pathologies,target,model_cmaps,threshold = threshold, alpha = alpha,obscureness = obscureness,fontsize=fontsize)
+
+
+ #### Initializing models
+ def probtemp(image:np.array)->dict:
+ """
+ Returns the output probabilities of two models
+ Args:
+ image (np.array): Numpy array, already scaled
+ """
+ #### Initializing models
+ models = OrderedDict()
+ cams = OrderedDict()
+
+ for model_name in ['densenet121-res224-mimic_nb', 'densenet121-res224-mimic_ch']:
+ #### Adding the models and cams to the OrderedDict structure
+ models,cams = takemodel(models,cams,weights=model_name)
+ ### Get an image
+ input_tensor = torch.from_numpy(image).unsqueeze(0)

+ img = input_tensor.numpy()[0, 0, :, :]
+ img = (img / 1024.0 / 2.0) + 0.5
+ img = np.clip(img, 0, 1)
+ img = Image.fromarray(np.uint8(img * 255) , 'L')

+ model_dics = {}
+ for model_name, model in models.items():
+ # get the model predictions
+ with torch.no_grad():
+ out = model(input_tensor).cpu()
+ model_dics[model_name] = {key:value for (key,value) in zip(model.pathologies, out.detach().numpy()[0]) if len(key)>2}
+ return model_dics
+ def getprobs(img):
+ """
+ Returns the output probabilities of the image for each model
+ Args:
+ img (np.array): image array
+
+ Original source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
+ Modifications by : ### TeamMIMICIV
+ """
+ pathologies = ['Atelectasis', 'Consolidation', 'Pneumothorax','Edema', 'Effusion', 'Pneumonia', 'Cardiomegaly']
+ tr = transforms.Compose(
+ [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224, engine='cv2')]
+ )
+ image = read_image(img,tr=tr)
+ return probtemp(image)
+
+
+
+
+ #### MORE FUNCTIONS.PY
+ ### Get the probability of models
+ def sortedmodels(probs,model_name):
+ """
+ Sorts the model's probabilities in descending order
+ Args:
+ probs (dict) : dictionary of model probabilities
+ model_name (str) : name of the model
+ """
+ ### Probability of the model
+ promodels = probs[model_name]
+ # Sort results in descending probability order
+ return dict(sorted(promodels.items(), key=operator.itemgetter(1),reverse=True))
+ def disprobs(model_probs,model_name,N):
+ """
+ Sorts and displays the model's probabilities
+ Args:
+ model_probs (dict) : dictionary of model probabilities
+ model_name (str) : name of the model
+ N (int) : number of top anomalies to display
+ """
+ exp1 = st.expander(f"Probabilities for {model_name}")
+ pr = sortedmodels(model_probs,model_name)
+ for cnt,(key,value) in enumerate(pr.items()):
+ if cnt==N:
+ break
+ exp1.metric(label=key, value=str(cnt+1), delta=str(value))
+
+ def getfile(uploaded_file=None):
+ """
+ Get the file uploaded
+ """
+ if uploaded_file is not None:
+ return extensionimages(uploaded_file)
+ return extensionimages("example.dcm")
  ### Error in case we do not find compounds
  def error(option):
  option = str(option).replace(" ","%20")
+ par3 = f'https://www.ebi.ac.uk/chembl/g/#search_results/all/query={option}'
+ par2 = "<a href = {} >".format(par3)
+ par = par2 + "ChEMBL" + "</a>"
+
+ st.markdown("<p style='text-align: center;'>We have not found compounds for this illness; for more information visit this link: {}</p>".format(par), unsafe_allow_html=True)
+
+ def main():
+
+ sys.path.insert(0,"..")
+ ### Title
+ st.set_page_config(layout="wide")
+ header()
+ ### Uploader
+ uploaded_file = st.file_uploader("Choose an X-Ray image to detect anomalies of the chest (the file must be a dicom extension or jpg)",)
+ #### Get the image
+
+ imgdef = getfile(uploaded_file)
+ __,col4,_,col5,_,col6,__ = st.columns((0.1,1,0.2,2.5,0.2,1,0.1))
+ col5.markdown("<h3 style='text-align: center;'>Input Image</h3>",unsafe_allow_html=True)
+ with col5:
+ ### Plot the input image
+ fig, ax = plt.subplots()
+ ax.imshow(imgdef,cmap="gray")
+ st.pyplot(fig=fig)
+ # Printing the possibility of having anomalies
+
+ __,col1,_,col3,_,col2,__ = st.columns((0.1,1,0.2,2.5,0.2,1,0.1))
+ col3.markdown("<h3 style='text-align: center;'>Anomaly Detection</h3>",unsafe_allow_html=True)
+ model_probs = getprobs(imgdef)
+ option_anomaly,threshold,colors,obscureness,option,weight_cutoff,logp_cutoff,NumHDonors_cutoff,NumHAcceptors_cutoff,max_phase,N = controllers2(model_probs)
+ ### MODEL 1
+ with col1:
+ disprobs(model_probs,model_names[0],N)
+ ### MODEL_2
+ with col2:
+ disprobs(model_probs,model_names[1],N)
+
+ ### ANOMALY HEATMAP
+ with col3:
+ plot = heatmap(imgdef,option_anomaly,threshold,colors,obscureness,14)
+ st.pyplot(plot)
+ df = getdrugs(option,max_phase)
+
+ st.markdown("<h3 style='text-align: center;'>Compounds for {}</h3>".format(option),unsafe_allow_html=True)
+ __,col10,col11,_,_,col12,__ = st.columns((0.1,0.8,2.5,0.2,0.2,1,0.1))
+
+ ### TREATMENT FILTERING
+ if df is not None:
+
+ #### Filter dataframe by controllers
+ df_result = df[df["mol_weight"] < weight_cutoff]
+ df_result2 = df_result[df_result["Logp"] < logp_cutoff]
+ df_result3 = df_result2[df_result2["Donnors"] < NumHDonors_cutoff]
+ df_result4 = df_result3[df_result3["Acceptors"] < NumHAcceptors_cutoff]
+
+
+
+ if len(df_result4)==0:
+
+ error(option)
+ else:
+ raw_html = mols2grid.display(df_result, mapping={"smiles": "SMILES","pref_name":"Name","Acceptors":"Acceptors","Donnors":"Donnors","Logp":"Logp","mol_weight":"mol_weight"},
+ subset=["img","Name"],tooltip=["Name","Acceptors","Donnors","Logp","mol_weight"],tooltip_placement="top",tooltip_trigger="click hover")._repr_html_()
+ with col11:
+
+ components.html(raw_html, width=900, height=900, scrolling=True)
+ #### No compounds found for the anomaly
+ else:

  error(option)

+ if __name__=="__main__":
+ main()
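
Note for reviewers: a minimal usage sketch of the refactored helpers introduced above, run outside Streamlit. This is not part of the commit; it assumes app.py is importable from the repo root, the bundled example.dcm referenced by getfile() is present, and the torchxrayvision weights can be downloaded on first use. It mirrors the three calls main() wires into the page.

from app import getfile, getprobs, heatmap, model_names

img = getfile()                                  # no upload -> falls back to the bundled example.dcm
probs = getprobs(img)                            # {model_name: {pathology: probability}}
top3 = sorted(probs[model_names[0]].items(), key=lambda kv: kv[1], reverse=True)[:3]
print(top3)                                      # three most likely anomalies for the first MIMIC model
plot = heatmap(img, "Cardiomegaly", threshold=0.3, alpha=0.8)   # returns matplotlib.pyplot with the Grad-CAM overlay
plot.savefig("cardiomegaly_gradcam.png")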
requirements.txt CHANGED
@@ -1,105 +1,21 @@
- rdkit-pypi
- mols2grid
- opencv-python-headless
- altair==4.1.0
- argcomplete==1.12.3
- argon2-cffi==21.1.0
- astor==0.8.1
- attrs==21.2.0
- backcall==0.2.0
- backports.zoneinfo==0.2.1
- base58==2.1.1
- bleach==4.1.0
- blinker==1.4
- Bottleneck==1.3.2
- cachetools==4.2.4
- certifi==2021.10.8
- cffi==1.15.0
- charset-normalizer==2.0.9
- chembl-webresource-client==0.10.7
- click==7.1.2
- colorama==0.4.4
- debugpy==1.5.1
- decorator==5.1.0
- defusedxml==0.7.1
  easydict==1.9
- entrypoints==0.3
- fonttools==4.25.0
- gitdb==4.0.9
- GitPython==3.1.24
- idna==3.3
- imageio==2.13.2
- importlib-metadata==4.8.2
- importlib-resources==5.4.0
- ipykernel==6.6.0
- ipython==7.30.1
- ipython-genutils==0.2.0
- ipywidgets==7.6.5
- itsdangerous==2.0.1
- jedi==0.18.1
- jsonschema==4.2.1
- jupyter-client==7.1.0
- jupyter-core==4.9.1
- jupyterlab-pygments==0.1.2
- jupyterlab-widgets==1.0.2
- matplotlib-inline==0.1.3
- mistune==0.8.4
- mkl-fft==1.3.1
- mkl-service==2.4.0
- munkres==1.1.4
- nbclient==0.5.9
- nbconvert==6.3.0
- nbformat==5.1.3
- nest-asyncio==1.5.4
- networkx==2.6.3
- notebook==6.4.6
- olefile==0.46
- pandocfilters==1.5.0
- parso==0.8.3
- pickleshare==0.7.5
  Pillow==8.4.0
- prometheus-client==0.12.0
- prompt-toolkit==3.0.23
- protobuf==3.19.1
- pyarrow==6.0.1
- pycparser==2.21
- pydeck==0.7.1
  pydicom==2.2.2
- Pygments==2.10.0
- Pympler==0.9
- pyparsing==3.0.6
- pyrsistent==0.18.0
- pytz==2021.3
- pytz-deprecation-shim==0.1.0.post0
- PyWavelets==1.2.0
- PyYAML==6.0
- pyzmq==22.3.0
- requests==2.26.0
- requests-cache==0.7.5
- scikit-image==0.19.0
  scipy==1.7.3
- Send2Trash==1.8.0
- smmap==5.0.0
  streamlit==1.2.0
- terminado==0.12.1
- testpath==0.5.0
- tifffile==2021.11.2
- toml==0.10.2
- toolz==0.11.2
  torch==1.10.0
  torchvision==0.11.1
  torchxrayvision==0.0.32
- tqdm==4.62.3
- traitlets==5.1.1
- typing_extensions==4.0.1
- tzdata==2021.5
- tzlocal==4.1
- url-normalize==1.4.3
- urllib3==1.26.7
- validators==0.18.2
- watchdog==2.1.6
- wcwidth==0.2.5
- webencodings==0.5.1
- widgetsnbextension==3.5.2
- wincertstore==0.2
- zipp==3.6.0
+ chembl_webresource_client==0.10.7
  easydict==1.9
+ grad_cam==1.3.5
+ matplotlib==3.5.0
+ mols2grid==0.1.0
+ numpy==1.21.2
+ opencv_python==4.5.4.60
+ pandas==1.3.4
  Pillow==8.4.0
  pydicom==2.2.2
+ rdkit==2009.Q1-1
+ scikit_image==0.19.0
+ scikit_learn==1.0.1
  scipy==1.7.3
+ skimage==0.0
  streamlit==1.2.0
+ tensorboardX==2.4.1
  torch==1.10.0
+ torchsummary==1.5.1
  torchvision==0.11.1
  torchxrayvision==0.0.32
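
The pinned rdkit and chembl_webresource_client entries above back the compound filtering in app.py: the sidebar cutoffs are compared against the mol_weight, Logp, Donnors and Acceptors columns of the dataframe returned by getdrugs(). A hedged sketch of how such descriptors can be computed with RDKit, useful for sanity-checking the cutoffs; the SMILES below is illustrative only and does not come from the app, and this is not necessarily how getdrugs() builds its columns.

from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski

smiles = "CC(=O)Oc1ccccc1C(=O)O"                 # aspirin, illustration only
mol = Chem.MolFromSmiles(smiles)
row = {
    "mol_weight": Descriptors.ExactMolWt(mol),   # same descriptor app.py imports (ExactMolWt)
    "Logp": Descriptors.MolLogP(mol),
    "Donnors": Lipinski.NumHDonors(mol),
    "Acceptors": Lipinski.NumHAcceptors(mol),
}
# Mirrors the strict "<" comparisons in main(); 500 / 5 / 5 / 10 are the slider defaults
passes = (row["mol_weight"] < 500 and row["Logp"] < 5
          and row["Donnors"] < 5 and row["Acceptors"] < 10)
print(row, passes)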