# Provenance (copied from the repository web view, not code): author saritha5,
# commit a712e98 "Update app.py", file size 4.33 kB.
from detecto import core, utils, visualize
from detecto.visualize import show_labeled_image, plot_prediction_grid
from torchvision import transforms
import matplotlib.pyplot as plt
from tensorflow.keras.utils import img_to_array
import numpy as np
import warnings
from PIL import Image
import streamlit as st
warnings.filterwarnings("ignore", category=UserWarning)
from tempfile import NamedTemporaryFile
import cv2
import matplotlib.patches as patches
import torch
import matplotlib.image as mpimg
import os
from detecto.utils import reverse_normalize, normalize_transform, _is_iterable
from torchvision import transforms
# Trained detector weights and the bundled sample image.
MODEL_PATH = "SD_model_weights.pth"
IMAGE_PATH = "img1.jpeg"

# Load the detector with the three class names it predicts.
model = core.Model.load(MODEL_PATH, ['cross_arm', 'pole', 'tag'])

st.title("Object Detection")

# Module-level predictions for the sample image (labels, boxes, scores).
image = utils.read_image(IMAGE_PATH)
predictions = model.predict(image)
labels, boxes, scores = predictions

# Sample images shown in the sidebar.
images = ["img1.jpeg", "img4.jpeg", "img5.jpeg", "img6.jpeg"]
st.sidebar.write("choose an image")
st.sidebar.image(images)
def detect_object(IMAGE_PATH, thresh=0.2):
    """Run the detector on one image file and show the annotated result.

    :param IMAGE_PATH: Path of the image file to analyse.
    :param thresh: Minimum confidence score a detection needs to be drawn.
        Defaults to 0.2, the value previously hard-coded here.
    """
    image = utils.read_image(IMAGE_PATH)
    # Predict on *this* image. Reusing the module-level predictions (made for
    # the bundled sample image) drew the sample's boxes on the upload.
    labels, boxes, scores = model.predict(image)

    # Keep only detections above the threshold; a boolean mask indexes the
    # box tensor directly and is zipped against the plain list of labels.
    keep = scores > thresh
    filtered_boxes = boxes[keep]
    filtered_labels = [label for label, kept in zip(labels, keep.tolist()) if kept]

    if not filtered_labels:
        st.write("No objects detected above the confidence threshold")
        return

    # show_image saves the annotated figure to disk and returns its path.
    # detecto's show_labeled_image is not called: it plt.show()s a figure
    # that Streamlit never displays and never closes.
    fig1 = show_image(image, filtered_boxes, filtered_labels)
    st.write("Object Detected Image is")
    st.image(fig1)
def show_image(image, boxes, labels=None, output_path=None):
    """Draw boxes (and optional labels) on an image and save it as a PNG.

    Adapted from detecto's ``visualize.show_labeled_image``. Instead of
    calling ``plt.show()``, it writes the figure to a file so that Streamlit
    can display it.

    :param image: The image to plot. If it is a ``torch.Tensor`` it is
        reverse-normalized and converted to a PIL image before plotting.
    :type image: numpy.ndarray or torch.Tensor
    :param boxes: Tensor of shape (N, 4) with one ``(x_min, y_min, x_max,
        y_max)`` row per box, or shape (4,) for a single box.
    :type boxes: torch.Tensor
    :param labels: (Optional) Sequence of N labels, where ``labels[i]``
        belongs to ``boxes[i]``. A single label that is not a list/tuple is
        wrapped in a list. Defaults to None (no labels drawn).
    :param output_path: (Optional) Where to write the PNG. Defaults to
        ``foo.png`` in the current working directory.
    :return: Absolute path of the written PNG file.
    :rtype: str
    """
    if output_path is None:
        output_path = os.path.join(os.getcwd(), 'foo.png')
    output_path = os.path.abspath(output_path)

    fig, ax = plt.subplots(1)
    try:
        # A tensor input is reverse-normalized and turned back into a PIL image.
        if isinstance(image, torch.Tensor):
            image = reverse_normalize(image)
            image = transforms.ToPILImage()(image)
        ax.imshow(image)

        # Accept a single box of shape (4,) as well as a batch of shape (N, 4).
        if boxes.ndim == 1:
            boxes = boxes.view(1, 4)
        if labels is not None and not _is_iterable(labels):
            labels = [labels]
        # Explicit length check instead of `if labels:` so that a
        # multi-element tensor of labels does not raise on truthiness.
        draw_labels = labels is not None and len(labels) > 0

        # Draw every box. The previous loop was hard-coded to range(2), which
        # skipped boxes after the second and raised IndexError for < 2 boxes.
        for i in range(boxes.shape[0]):
            box = boxes[i]
            x_min, y_min = box[0].item(), box[1].item()
            width = box[2].item() - x_min
            height = box[3].item() - y_min
            rect = patches.Rectangle((x_min, y_min), width, height, linewidth=1,
                                     edgecolor='r', facecolor='none')
            if draw_labels:
                ax.text(x_min + 5, y_min - 5, '{}'.format(labels[i]), color='red')
            ax.add_patch(rect)

        # NOTE(review): with the default path every call overwrites the same
        # file, so concurrent sessions may race on it; pass output_path to avoid.
        fig.savefig(output_path)
    finally:
        # Always release the figure so repeated calls do not accumulate figures.
        plt.close(fig)
    return output_path
# Main flow: ask for an upload, echo it, then run detection on a temp copy.
file = st.file_uploader('Upload an Image', type=["jpeg", "jpg", "png"])
if file is None:
    st.write("Please upload an image file")
else:
    image = Image.open(file)
    st.write("Input Image")
    st.image(image, use_column_width=True)
    # detect_object takes a file path, so persist the uploaded bytes to a
    # temporary file in the working directory for the duration of the call.
    with NamedTemporaryFile(dir='.', suffix='.jpeg') as f:
        f.write(file.getbuffer())
        # Flush Python's write buffer so the bytes are on disk before the
        # file is reopened by name; otherwise a truncated/empty file is read.
        f.flush()
        detect_object(f.name)