Spaces:
Runtime error
Runtime error
from detecto import core, utils, visualize | |
from detecto.visualize import show_labeled_image, plot_prediction_grid | |
from torchvision import transforms | |
import matplotlib.pyplot as plt | |
from tensorflow.keras.utils import img_to_array | |
import numpy as np | |
import warnings | |
from PIL import Image | |
import streamlit as st | |
# Suppress UserWarning-category warnings for the remainder of the run.
warnings.filterwarnings(action="ignore", category=UserWarning)

from tempfile import NamedTemporaryFile

# Trained detector weights and the default demo image.
MODEL_PATH = "SD_model_weights.pth"
IMAGE_PATH = "img1.jpeg"

# Class names passed to detecto; presumably the order used at training time — confirm.
CLASS_NAMES = ['cross_arm', 'pole', 'tag']
model = core.Model.load(MODEL_PATH, CLASS_NAMES)

st.title("Object Detection")

# Detections for the default image, kept at module level as
# `predictions` and its unpacked (labels, boxes, scores) parts.
image = utils.read_image(IMAGE_PATH)
labels, boxes, scores = predictions = model.predict(image)

# Sample images listed in the sidebar.
# NOTE(review): "img3.jpeg" appears twice — possibly meant "img4.jpeg"; confirm.
images = ["img1.jpeg", "img2.jpeg", "img3.jpeg", "img3.jpeg"]
st.sidebar.write("choose an image")
st.sidebar.image(images)
def detect_object(IMAGE_PATH, *, thresh=0.2):
    """Run the detector on an image file and render its confident detections.

    The image at ``IMAGE_PATH`` is read with detecto's ``utils.read_image``,
    passed through the module-level ``model``, and every detection whose
    score is strictly greater than ``thresh`` is drawn on the image and shown
    in the Streamlit page.

    Args:
        IMAGE_PATH: Path of the image file to analyse. (The name shadows the
            module-level constant of the same name; it is kept so existing
            keyword callers keep working.)
        thresh: Detections with a score at or below this value are hidden.
            Defaults to 0.2.

    Returns:
        A ``(labels, boxes, scores)`` tuple containing only the detections
        that passed the threshold.
    """
    image = utils.read_image(IMAGE_PATH)
    # Predict on *this* image rather than reusing module-level predictions
    # computed for the default demo image.
    labels, boxes, scores = model.predict(image)

    keep = scores > thresh  # boolean mask, one entry per detection
    filtered_boxes = boxes[keep]
    filtered_scores = scores[keep]
    filtered_labels = [label for label, kept in zip(labels, keep.tolist()) if kept]

    # Streamlit has no ``show_labeled_image``; detecto's helper draws onto a
    # fresh matplotlib figure, which is then handed to st.pyplot.
    show_labeled_image(image, filtered_boxes, filtered_labels)
    fig = plt.gcf()
    st.pyplot(fig)
    # pyplot keeps figures alive until closed; avoid accumulating one per rerun.
    plt.close(fig)

    return filtered_labels, filtered_boxes, filtered_scores
# Upload flow: show the uploaded image, then run detection on it.
file = st.file_uploader('Upload an Image', type=["jpeg", "jpg", "png"])
if file is None:
    st.write("Please upload an image file")
else:
    image = Image.open(file)
    st.image(image, use_column_width=True)

    # detect_object() takes a path, so spill the upload to a temporary file
    # that keeps the upload's own extension (not a misleading '.csv').
    _, dot, ext = file.name.rpartition(".")
    suffix = dot + ext if dot else ""
    with NamedTemporaryFile(suffix=suffix) as f:
        f.write(file.getbuffer())
        # Flush so the bytes are on disk before detect_object() reopens the
        # file by name; otherwise it may see an empty or truncated file.
        f.flush()
        detect_object(f.name)