Talal commited on
Commit
f3200e4
1 Parent(s): ac14de4
Files changed (5) hide show
  1. bg.png +0 -0
  2. main.py +59 -0
  3. model_0001999.pth +3 -0
  4. requirements.txt +9 -0
  5. util.py +103 -0
bg.png ADDED
main.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from detectron2.config import get_cfg
3
+ from detectron2.engine import DefaultPredictor
4
+ from detectron2 import model_zoo
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
+ from util import visualize, set_background
9
+
10
+
11
+ set_background('bg.png')
12
+
13
+
14
+ # set title
15
+ st.title('Brain MRI tumor detection')
16
+
17
+ # set header
18
+ st.header('Please upload an image')
19
+
20
+ # upload file
21
+ file = st.file_uploader('', type=['png', 'jpg', 'jpeg'])
22
+
23
+ # load model
24
+ cfg = get_cfg()
25
+ cfg.merge_from_file(model_zoo.get_config_file('COCO-Detection/retinanet_R_101_FPN_3x.yaml'))
26
+ cfg.MODEL.WEIGHTS = 'model_0001999.pth'
27
+ cfg.MODEL.DEVICE = 'cpu'
28
+
29
+ predictor = DefaultPredictor(cfg)
30
+
31
+ # load image
32
+ if file:
33
+ image = Image.open(file).convert('RGB')
34
+
35
+ image_array = np.asarray(image)
36
+
37
+ # detect objects
38
+ outputs = predictor(image_array)
39
+
40
+ threshold = 0.5
41
+
42
+ # Display predictions
43
+ preds = outputs["instances"].pred_classes.tolist()
44
+ scores = outputs["instances"].scores.tolist()
45
+ bboxes = outputs["instances"].pred_boxes
46
+
47
+ bboxes_ = []
48
+ for j, bbox in enumerate(bboxes):
49
+ bbox = bbox.tolist()
50
+
51
+ score = scores[j]
52
+ pred = preds[j]
53
+
54
+ if score > threshold:
55
+ x1, y1, x2, y2 = [int(i) for i in bbox]
56
+ bboxes_.append([x1, y1, x2, y2])
57
+
58
+ # visualize
59
+ visualize(image, bboxes_)
model_0001999.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea65ddcf276040cb0d20480482d5d8e8964deccfc18641ce583bf5ee99f97262
3
+ size 455037799
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.23.1
2
+ Pillow==9.5.0
3
+ numpy==1.24.3
4
+ torch==2.0.0
5
+ torchvision==0.15.1
6
+ opencv-python==4.6.0.66
7
+ matplotlib==3.5.3
8
+ plotly==5.15.0
9
+ git+https://github.com/facebookresearch/detectron2.git
util.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+
3
+ import plotly.graph_objects as go
4
+ import streamlit as st
5
+
6
+
7
+ def set_background(image_file):
8
+ """
9
+ This function sets the background of a Streamlit app to an image specified by the given image file.
10
+
11
+ Parameters:
12
+ image_file (str): The path to the image file to be used as the background.
13
+
14
+ Returns:
15
+ None
16
+ """
17
+ with open(image_file, "rb") as f:
18
+ img_data = f.read()
19
+ b64_encoded = base64.b64encode(img_data).decode()
20
+ style = f"""
21
+ <style>
22
+ .stApp {{
23
+ background-image: url(data:image/png;base64,{b64_encoded});
24
+ background-size: cover;
25
+ }}
26
+ </style>
27
+ """
28
+ st.markdown(style, unsafe_allow_html=True)
29
+
30
+
31
+ def visualize(image, bboxes):
32
+ """
33
+ Visualizes the image with bounding boxes using Plotly.
34
+
35
+ Args:
36
+ image: The input image.
37
+ bboxes (list): A list of bounding boxes in the format [x1, y1, x2, y2].
38
+
39
+ """
40
+ # Get the width and height of the image
41
+ width, height = image.size
42
+
43
+ shapes = []
44
+ for bbox in bboxes:
45
+ x1, y1, x2, y2 = bbox
46
+
47
+ # Convert bounding box coordinates to the format expected by Plotly
48
+ shapes.append(dict(
49
+ type="rect",
50
+ x0=x1,
51
+ y0=height - y2,
52
+ x1=x2,
53
+ y1=height - y1,
54
+ line=dict(color='red', width=6),
55
+ ))
56
+
57
+ fig = go.Figure()
58
+
59
+ # Add the image as a layout image
60
+ fig.update_layout(
61
+ images=[dict(
62
+ source=image,
63
+ xref="x",
64
+ yref="y",
65
+ x=0,
66
+ y=height,
67
+ sizex=width,
68
+ sizey=height,
69
+ sizing="stretch"
70
+ )]
71
+ )
72
+
73
+ # Set the axis ranges and disable axis labels
74
+ fig.update_xaxes(range=[0, width], showticklabels=False)
75
+ fig.update_yaxes(scaleanchor="x",
76
+ scaleratio=1,
77
+ range=[0, width], showticklabels=False)
78
+
79
+ fig.update_layout(
80
+ height=800,
81
+ updatemenus=[
82
+ dict(
83
+ direction='left',
84
+ pad=dict(r=10, t=10),
85
+ showactive=True,
86
+ x=0.11,
87
+ xanchor="left",
88
+ y=1.1,
89
+ yanchor="top",
90
+ type="buttons",
91
+ buttons=[
92
+ dict(label="Original",
93
+ method="relayout",
94
+ args=["shapes", []]),
95
+ dict(label="Detections",
96
+ method="relayout",
97
+ args=["shapes", shapes])
98
+ ],
99
+ )
100
+ ]
101
+ )
102
+
103
+ st.plotly_chart(fig)