jeyasee commited on
Commit
f7beb5f
·
verified ·
1 Parent(s): 87ee583

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +139 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,142 @@
1
- import altair as alt
 
 
 
 
 
 
 
2
  import numpy as np
3
  import pandas as pd
4
- import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchvision
4
+ import torchvision.transforms as transforms
5
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
6
+ from torchvision.transforms import ToTensor
7
+ from PIL import Image, ImageDraw
8
+ import cv2
9
  import numpy as np
10
  import pandas as pd
11
+ import os
12
+
13
+
14
+
15
+ import tempfile
16
+ from tempfile import NamedTemporaryFile
17
+
18
+ # Create an FRCNN model instance with the same structure as the saved model
19
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(num_classes=91)
20
+
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+
23
+ # Load the saved parameters into the model
24
+ model.load_state_dict(torch.load("frcnn_model.pth"))
25
+
26
+ # Define the classes for object detection
27
+ classes = [
28
+ '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
29
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
30
+ 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
31
+ 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
32
+ 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
33
+ 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
34
+ 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
35
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
36
+ 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
37
+ 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A',
38
+ 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
39
+ 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase',
40
+ 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
41
+ ]
42
+
43
+ # Set the threshold for object detection. It is IoU (Intersection over Union)
44
+ threshold = 0.5
45
+
46
+ st.title(""" Image Object Detection """)
47
+
48
+ # st.subheader("Prediction of Object Detection")
49
+
50
+ st.write(""" The Faster R-CNN (Region-based Convolutional Neural Network) is a cutting-edge object detection model that combines deep
51
+ learning with region proposal networks to achieve highly accurate object detection in images.
52
+ It is trained on a large dataset of images and can detect a wide range of objects with high Precision and Recall.
53
+ The model is based on the ResNet-50 architecture, which allows it to capture complex visual features from the input image.
54
+ It uses a two-stage approach, first proposing regions of interest (RoIs) in the image and then classifying and refining the
55
+ object boundaries within these RoIs. This approach makes it extremely efficient and accurate in detecting multiple objects
56
+ in a single image.
57
+ """)
58
+
59
+
60
+ # images = ["test2.jpg","img7.jpg","img20.jpg","img23.jpg"]
61
+ # with st.sidebar:
62
+ # st.write("Choose an image")
63
+ # selected_image = st.selectbox("Select an image", images)
64
+
65
+
66
+ images = ["test2.jpg","img7.jpg","img20.jpg","img23.jpg"]
67
+ with st.sidebar:
68
+ st.write("Choose an image")
69
+ st.image(images)
70
+
71
+
72
+ # define the function to perform object detection on an image
73
+ def detect_objects(image_path):
74
+ # load the image
75
+ image = Image.open(image_path).convert('RGB')
76
+
77
+ # convert the image to a tensor
78
+ image_tensor = ToTensor()(image).to(device)
79
+
80
+ # run the image through the model to get the predictions
81
+ model.eval()
82
+ with torch.no_grad():
83
+ predictions = model([image_tensor])
84
+
85
+ # filter out the predictions below the threshold
86
+ scores = predictions[0]['scores'].cpu().numpy()
87
+ boxes = predictions[0]['boxes'].cpu().numpy()
88
+ labels = predictions[0]['labels'].cpu().numpy()
89
+ mask = scores > threshold
90
+ scores = scores[mask]
91
+ boxes = boxes[mask]
92
+ labels = labels[mask]
93
+
94
+ # create a new image with the predicted objects outlined in rectangles
95
+ draw = ImageDraw.Draw(image)
96
+ for box, label in zip(boxes, labels):
97
+
98
+ # draw the rectangle around the object
99
+ draw.rectangle([(box[0], box[1]), (box[2], box[3])], outline='red')
100
+
101
+ # write the object class above the rectangle
102
+ class_name = classes[label]
103
+ draw.text((box[0], box[1]), class_name, fill='yellow')
104
+
105
+ # show the image
106
+ st.write("Obects detected in the image are: ")
107
+ st.image(image, use_column_width=True)
108
+ # st.image.show()
109
+
110
+
111
+ file = st.file_uploader('Upload an Image', type=(["jpeg", "jpg", "png"]))
112
+
113
+ if file is None:
114
+ st.write("Please upload an image file")
115
+ else:
116
+ image = Image.open(file)
117
+ st.write("Input Image")
118
+ st.image(image, use_column_width=True)
119
+ with NamedTemporaryFile(dir='.', suffix='.jpeg') as f: # this line gives error and only accepts .jpeg and so used above snippet
120
+ f.write(file.getbuffer()) # which will accepts all formats of images.
121
+ # your_function_which_takes_a_path(f.name)
122
+ detect_objects(f.name)
123
+
124
+ # if file is None:
125
+ # st.write("Please upload an image file")
126
+ # else:
127
+ # image = Image.open(file)
128
+ # st.write("Input Image")
129
+ # st.image(image, use_column_width=True)
130
+ # with NamedTemporaryFile(dir='.', suffix='.' + file.name.split('.')[-1]) as f:
131
+ # f.write(file.getbuffer())
132
+ # # your_function_which_takes_a_path(f.name)
133
+ # detect_objects(f.name)
134
+
135
 
136
+ st.write(""" This Streamlit app provides a user-friendly interface for uploading an image and visualizing the output of the Faster R-CNN
137
+ model. It displays the uploaded image along with the predicted objects highlighted with bounding box overlays. The app allows
138
+ users to explore the detected objects in the image, providing valuable insights and understanding of the model's predictions.
139
+ It can be used for a wide range of applications, such as object recognition, image analysis, and visual storytelling.
140
+ Whether it's identifying objects in real-world images or understanding the capabilities of state-of-the-art object detection
141
+ models, this Streamlit app powered by Faster R-CNN is a powerful tool for computer vision tasks.
142
+ """)