pamixsun commited on
Commit
f2c28c8
1 Parent(s): 3d86810

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2023, Xu Sun.
2
+
3
+ # This program is licensed under the Apache License version 2.
4
+ # See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0.txt> for full license details.
5
+
6
+ import torch
7
+ import numpy as np
8
+
9
+ import matplotlib.pyplot as plt
10
+ import streamlit as st
11
+
12
+ from PIL import Image
13
+ from lib.glaucoma import GlaucomaModel
14
+
15
+ run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16
+
17
+
18
+ def main():
19
+ # Wide mode
20
+ st.set_page_config(layout="wide")
21
+
22
+ # Designing the interface
23
+ st.title("Glaucoma Screening from Retinal Fundus Images")
24
+ # For newline
25
+ st.write('\n')
26
+ # Author info
27
+ st.write('Developed by X. Sun. Find more info about me: https://pamixsun.github.io')
28
+ # For newline
29
+ st.write('\n')
30
+ # Instructions
31
+ st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
32
+ # Set the columns
33
+ cols = st.beta_columns((1, 1))
34
+ cols[0].subheader("Input image")
35
+ cols[1].subheader("Class activation map")
36
+
37
+ # set the visualization figure
38
+ fig, ax = plt.subplots()
39
+
40
+ # Sidebar
41
+ # File selection
42
+ st.sidebar.title("Image selection")
43
+ # Disabling warning
44
+ st.set_option('deprecation.showfileUploaderEncoding', False)
45
+ # Choose your own image
46
+ uploaded_file = st.sidebar.file_uploader("Upload image", type=['png', 'jpeg', 'jpg'])
47
+ if uploaded_file is not None:
48
+ # read the upload image
49
+ image = Image.open(uploaded_file).convert('RGB')
50
+ image = np.array(image).astype(np.uint8)
51
+ # page_idx = 0
52
+ ax.imshow(image)
53
+ ax.axis('off')
54
+ cols[0].pyplot(fig)
55
+
56
+ # For newline
57
+ st.sidebar.write('\n')
58
+
59
+ # actions
60
+ if st.sidebar.button("Analyze image"):
61
+
62
+ if uploaded_file is None:
63
+ st.sidebar.write("Please upload an image")
64
+
65
+ else:
66
+ with st.spinner('Loading model...'):
67
+ # load model
68
+ model = GlaucomaModel(device=run_device)
69
+
70
+ with st.spinner('Analyzing...'):
71
+ # Forward the image to the model and get results
72
+ disease_idx, cam = model.process(image)
73
+
74
+ # visualize results
75
+ # fig, ax = plt.subplots()
76
+
77
+ # plot the stitched image
78
+ ax.imshow(cam)
79
+ ax.axis('off')
80
+ cols[1].pyplot(fig)
81
+
82
+ # Display JSON
83
+ st.subheader(" Screening results:")
84
+ st.write('\n')
85
+ st.markdown(f"{model.id2label[disease_idx]}")
86
+
87
+
88
+ if __name__ == '__main__':
89
+ main()