youl's picture
Update app.py
21b130c
raw
history blame contribute delete
No virus
3.98 kB
import gradio as gr
import torch
import cv2
import os
import torch.nn as nn
import numpy as np
import torchvision
from torchvision.ops import box_iou
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from timeit import default_timer as timer
from typing import Tuple, Dict
# apply nms algorithm
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Apply non-maximum suppression to one image's detection result.

    Args:
        orig_prediction: Dict with at least ``"boxes"``, ``"scores"`` and
            ``"labels"`` tensors (as returned for one image by the model in
            ``predict``). Boxes are passed to ``torchvision.ops.nms``, which
            expects ``(x1, y1, x2, y2)`` format.
        iou_thresh: A box whose IoU with a higher-scoring kept box exceeds
            this value is discarded.

    Returns:
        A new dict with the same keys as ``orig_prediction``. In it,
        ``boxes``, ``scores`` and ``labels`` contain only the boxes kept by
        NMS, in decreasing score order. The input dict is not modified.
    """
    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(
        orig_prediction["boxes"], orig_prediction["scores"], iou_thresh
    )
    # Shallow copy: the previous version aliased and mutated the caller's dict.
    final_prediction = dict(orig_prediction)
    for key in ("boxes", "scores", "labels"):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
# Draw the bounding box
def plot_img_bbox(img, target, size=1024):
    """Draw every predicted box and a "palm" label onto ``img``, in place.

    Args:
        img: Image array of shape (H, W) or (H, W, C). It is drawn on in
            place and also returned.
        target: Dict whose ``"boxes"`` entry holds (x1, y1, x2, y2) boxes in
            the coordinate frame of the ``size`` x ``size`` image the model saw.
        size: Side length of the square model input (the Resize in
            ``test_transforms`` uses 1024x1024). Used to rescale boxes back
            to ``img``'s own width and height.

    Returns:
        ``img`` with the rectangles and labels drawn on it.
    """
    h, w = img.shape[:2]  # works for both grayscale and multi-channel images
    # NOTE(review): (0, 0, 255) is red in OpenCV's BGR convention, but the
    # array coming from Gradio/PIL is RGB, so it renders blue — confirm intent.
    color = (0, 0, 255)
    label = "palm"
    # Move all boxes off the accelerator once instead of once per coordinate.
    boxes = target["boxes"].detach().cpu()
    for box in boxes:
        x1, y1, x2, y2 = box.tolist()
        xmin, ymin = int(x1 / size * w), int(y1 / size * h)
        xmax, ymax = int(x2 / size * w), int(y2 / size * h)
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)
        # Class label, placed just above the box's top-left corner.
        cv2.putText(img, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    return img
# transform image
# Inference preprocessing: resize to the detector's 1024x1024 input,
# normalize with Albumentations' default (ImageNet) mean/std, then convert
# the HWC numpy array to a CHW torch tensor. ``p=1.0`` replaces the
# deprecated ``always_apply=True`` and means "always apply".
test_transforms = A.Compose([
    A.Resize(height=1024, width=1024, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0),
])
# select device (whether GPU or CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model loading
# The checkpoint is a whole pickled module (not just a state_dict), so
# weights_only=False is required. PyTorch >= 2.6 defaults to
# weights_only=True and would refuse to load it. This unpickles arbitrary
# objects, so only ever point it at a trusted, bundled checkpoint.
model = torch.load("pickel.pth", map_location=torch.device("cpu"), weights_only=False)
model = model.to(device)
model.eval()
# Gradio callback: returns (annotated image, detection-count message).
def predict(img):
    """Detect palm trees in one image.

    Args:
        img: The uploaded image. With the default ``gr.Image()`` input used by
            this app it is presumably an H x W x 3 numpy array; anything
            ``np.array`` can turn into such an array also works. ``None``
            when the form is submitted without an image.

    Returns:
        Tuple ``(annotated, message)``:

        - ``annotated`` is a copy of the input with the boxes kept after NMS
          drawn on it.
        - ``message`` is a text reporting how many boxes were detected.

        If ``img`` is ``None``, returns ``(None, "Please upload an image.")``.
    """
    if img is None:
        # Gradio passes None when the user submits without uploading.
        return None, "Please upload an image."

    image_array = np.array(img)

    # Resize/normalize to a CHW tensor, add a batch dimension, move to device.
    batch = test_transforms(image=image_array)["image"].unsqueeze(0).to(device)

    # inference (single-image batch, so take the first result)
    model.eval()
    with torch.no_grad():
        predictions = model(batch)[0]

    # Lower IoU threshold than apply_nms's default of 0.3, i.e. more
    # aggressive suppression of overlapping boxes.
    nms_prediction = apply_nms(predictions, iou_thresh=0.1)

    # Draw on a copy so the array given to the transform is left untouched.
    annotated = plot_img_bbox(image_array.copy(), nms_prediction)
    word = "Number of palm trees detected : " + str(len(nms_prediction["boxes"]))
    return annotated, word
# Gradio I/O components: the uploaded input image, the annotated output
# image, and a text label carrying the detection count.
image, out_im, out_lab = (
    gr.components.Image(),
    gr.components.Image(),
    gr.components.Label(),
)
### 4. Gradio app ###
# Create title, description and article strings
title = "🌴Palm trees detection🌴"
description = "Faster R-CNN model to detect oil palm trees in drone images."
article = "Created by data354."

# Create examples list from "examples/" directory.
# Sorted for a stable UI order. Only image files are kept, so stray files
# (e.g. .gitkeep, .DS_Store) do not become broken examples. A missing
# directory yields no examples instead of crashing the app at startup.
EXAMPLES_DIR = "examples"
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
example_list = [
    [os.path.join(EXAMPLES_DIR, name)]
    for name in (sorted(os.listdir(EXAMPLES_DIR)) if os.path.isdir(EXAMPLES_DIR) else [])
    if name.lower().endswith(IMAGE_EXTENSIONS)
    and os.path.isfile(os.path.join(EXAMPLES_DIR, name))
]
# Create the Gradio demo
demo = gr.Interface(
    fn=predict,  # mapping function from input to output
    inputs=image,  # uploaded image; predict converts it with np.array
    outputs=[out_im, out_lab],  # annotated image + detection-count text
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

# Launch the demo only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    demo.launch(debug=False)