# app.py — Streamlit object-detection demo (Hugging Face Space "Test1", revision 62fd502)
import cv2
import numpy as np
from PIL import Image
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
# most of this code has been obtained from Datature's prediction script
# https://github.com/datature/resources/blob/main/scripts/bounding_box/prediction.py
# Silence Streamlit's deprecation warning about the old file-uploader encoding behaviour.
st.set_option('deprecation.showfileUploaderEncoding', False)
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the object-detection SavedModel from disk.

    Cached by Streamlit so the model is loaded once per session rather than
    on every rerun of the script.

    NOTE(review): this definition shadows the ``load_model`` imported from
    ``tensorflow.keras.models`` above, which is therefore never used.
    """
    detector = tf.saved_model.load('./saved_model')
    return detector
def load_label_map(label_map_path):
    """Parse a ``.pbtxt`` label map into a lookup dictionary.

    Assumes the standard object-detection label-map layout, where each
    ``id:`` line is immediately followed by its matching ``name:`` line.

    Args:
        label_map_path: File path of the ``label_map.pbtxt`` to read.

    Returns:
        Dict of the form ``{label_index: {'id': label_index, 'name': label_name}}``.
    """
    parsed = {}
    with open(label_map_path, "r") as label_file:
        for raw_line in label_file:
            if "id" not in raw_line:
                continue
            index = int(raw_line.split(":")[-1])
            # The name sits on the very next line; drop whitespace and quotes.
            raw_name = next(label_file).split(":")[-1]
            parsed[index] = {"id": index, "name": raw_name.strip().strip('"')}
    return parsed
def predict_class(image, model):
    """Run a Keras classifier on a single image.

    The image is cast to float32, resized to 150x150, and given a leading
    batch axis before being handed to ``model.predict``.

    Args:
        image: Image tensor/array accepted by ``tf.image.resize``.
        model: Keras-style model exposing ``predict``.

    Returns:
        Whatever ``model.predict`` returns for the one-image batch.
    """
    resized = tf.image.resize(tf.cast(image, tf.float32), [150, 150])
    batch = np.expand_dims(resized, axis=0)
    return model.predict(batch)
def plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape,
                      category_index=None, scores=None):
    """Draw detection boxes, label backgrounds and captions onto an image.

    Args:
        color_map: Dict mapping class index -> color triple used by OpenCV.
        classes: Sequence of integer class indices, one per box.
        bboxes: Normalized ``[ymin, xmin, ymax, xmax]`` boxes in [0, 1].
        image_origi: Image array drawn on in place (and also returned).
        origi_shape: ``(height, width, ...)`` of ``image_origi``.
        category_index: Label map ``{index: {'id': ..., 'name': ...}}``. If
            None, falls back to the module-level global of the same name
            (backward compatible with the original implementation, which
            silently read the global).
        scores: Per-box confidence scores; same global fallback as above.

    Returns:
        The annotated ``image_origi`` array.
    """
    # BUGFIX(coupling): the original body referenced `category_index` and
    # `scores` as module globals without declaring them as inputs, which made
    # the function fail with NameError when called before the script tail ran.
    # They are now explicit (optional) parameters with a global fallback.
    if category_index is None:
        category_index = globals()["category_index"]
    if scores is None:
        scores = globals()["scores"]
    for idx, each_bbox in enumerate(bboxes):
        color = color_map[classes[idx]]
        # Convert normalized box coordinates to pixel coordinates once.
        ymin = int(each_bbox[0] * origi_shape[0])
        xmin = int(each_bbox[1] * origi_shape[1])
        ymax = int(each_bbox[2] * origi_shape[0])
        xmax = int(each_bbox[3] * origi_shape[1])
        ## Draw bounding box
        cv2.rectangle(image_origi, (xmin, ymin), (xmax, ymax), color, 2)
        ## Draw label background: a filled strip just below the box
        cv2.rectangle(image_origi, (xmin, ymax), (xmax, ymax + 15), color, -1)
        ## Insert label class & score on top of the strip
        cv2.putText(
            image_origi,
            "Class: {}, Score: {}".format(
                str(category_index[classes[idx]]["name"]),
                str(round(scores[idx], 2)),
            ),
            (xmin, ymax + 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.3,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
    return image_origi
# Webpage code starts here
#TODO change this
st.title('YOUR PROJECT NAME')
st.text('made by XXX')
st.markdown('## Description about your project')
# Load the SavedModel once (st.cache inside load_model memoizes it) while
# showing a spinner so the user knows the app is still starting up.
with st.spinner('Model is being loaded...'):
    model = load_model()
# Ask the user to upload an image.
# BUGFIX(generalize): the original only allowed ["jpg", "png"], so files with
# the equally-common ".jpeg" extension were rejected by the uploader widget.
file = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])

if file is None:
    st.text('Waiting for upload...')
else:
    st.text('Running inference...')
    # Open the upload and force 3-channel RGB; remember the original shape
    # so the annotated result can be shown at the source resolution.
    test_image = Image.open(file).convert("RGB")
    origi_shape = np.asarray(test_image).shape
    # Resize to the fixed input size the detector expects.
    default_shape = 320
    image_resized = np.array(test_image.resize((default_shape, default_shape)))

    ## Load color map
    category_index = load_label_map("./label_map.pbtxt")

    # TODO Add more colors if there are more classes
    # color of each label. check label_map.pbtxt to check the index for each class
    color_map = {
        1: [255, 0, 0],  # bad -> red
        2: [0, 255, 0],  # good -> green
    }

    ## The model input needs to be a tensor
    input_tensor = tf.convert_to_tensor(image_resized)
    ## The model expects a batch of images, so add an axis with `tf.newaxis`.
    input_tensor = input_tensor[tf.newaxis, ...]

    ## Feed image into model and obtain output
    detections_output = model(input_tensor)
    num_detections = int(detections_output.pop("num_detections"))
    # Strip the batch axis and keep only the first num_detections entries.
    detections = {key: value[0, :num_detections].numpy()
                  for key, value in detections_output.items()}
    detections["num_detections"] = num_detections

    ## Filter out predictions below threshold
    # if threshold is higher, there will be fewer predictions
    # TODO change this number to see how the predictions change
    confidence_threshold = 0.8
    indexes = np.where(detections["detection_scores"] > confidence_threshold)

    ## Extract predicted bounding boxes
    bboxes = detections["detection_boxes"][indexes]
    if len(bboxes) == 0:
        # there are no predicted boxes
        st.error('No boxes predicted')
    else:
        # there are predicted boxes
        st.success('Boxes predicted')
        classes = detections["detection_classes"][indexes].astype(np.int64)
        scores = detections["detection_scores"][indexes]
        # Plot boxes and labels on the (resized-back-to-original) image.
        image_origi = np.array(
            Image.fromarray(image_resized).resize((origi_shape[1], origi_shape[0]))
        )
        image_origi = plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape)
        # Show the annotated image and a text summary of every detection.
        st.image(Image.fromarray(image_origi), caption="Image with predictions", width=400)
        st.markdown("### Predicted boxes")
        for idx in range(len(bboxes)):
            st.markdown(f"* Class: {str(category_index[classes[idx]]['name'])}, confidence score: {str(round(scores[idx], 2))}")