# Header text copied from the hosting page's file view (not part of the program):
#   youl's picture
#   Upload 2 files
#   a0abb52
#   raw | history | blame
#   No virus
#   3.63 kB
import gradio as gr
import torch
import cv2
import os
import torch.nn as nn
import numpy as np
from torchvision.ops import box_iou
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
from timeit import default_timer as timer
from typing import Tuple, Dict
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Apply non-maximum suppression (NMS) to a single-image detection result.

    Args:
        orig_prediction: Detection dict with ``'boxes'``, ``'scores'`` and
            ``'labels'`` tensors (presumably the per-image output of a
            torchvision detection model, boxes ``(N, 4)`` as
            ``(x1, y1, x2, y2)`` -- TODO confirm).
        iou_thresh: A box whose IoU with a higher-scoring kept box exceeds
            this value is discarded.

    Returns:
        A new dict with the same keys as ``orig_prediction`` in which
        ``'boxes'``, ``'scores'`` and ``'labels'`` are restricted to the kept
        detections. The input dict is not modified.
    """
    # Imported locally: the module only does ``from torchvision.ops import
    # box_iou``, which does not bind the name ``torchvision`` itself.
    from torchvision.ops import nms

    # torchvision returns the indices of the boxes to keep.
    keep = nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)

    # Copy rather than alias so the caller's prediction dict stays intact.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
def plot_img_bbox(img, target, filename='pred.jpg'):
    """Draw each detected box, labelled "palm", on ``img`` and save the result.

    Args:
        img: Image to annotate. A ``PIL.Image.Image`` (RGB) is first converted
            to a BGR NumPy array, because ``cv2`` cannot draw on PIL images
            and ``cv2.imwrite`` expects BGR. A NumPy array is drawn on in
            place and is assumed to already be in BGR channel order --
            NOTE(review): confirm for any new caller.
        target: Mapping whose ``'boxes'`` entry iterates over boxes laid out
            as ``(xmin, ymin, xmax, ymax)`` in ``img`` pixel coordinates
            (tensors or arrays; anything with ``.tolist()``).
        filename: Path the annotated image is written to.

    Returns:
        The annotated BGR ``numpy.ndarray`` (the input array itself when a
        NumPy array was given).
    """
    if isinstance(img, Image.Image):
        canvas = cv2.cvtColor(np.array(img.convert('RGB')), cv2.COLOR_RGB2BGR)
    else:
        canvas = img

    colour = (0, 0, 255)  # red in BGR order
    label = "palm"
    for box in target['boxes']:
        # .tolist() works for CPU/CUDA tensors and NumPy arrays; int()
        # truncates exactly like the previous int(tensor) conversion.
        xmin, ymin, xmax, ymax = map(int, box.tolist()[:4])
        cv2.rectangle(canvas, (xmin, ymin), (xmax, ymax), colour, 2)
        # Place the label just above the box, but keep it on-image when the
        # box touches the top edge (font scale 0.9 is roughly 20 px tall).
        text_y = max(ymin - 10, 20)
        cv2.putText(canvas, label, (xmin, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, colour, 2)

    cv2.imwrite(filename, canvas)
    return canvas
# Preprocessing applied to every input image before inference: resize to
# 1024x1024, normalise with A.Normalize's default arguments, then convert the
# HWC NumPy image to a CHW torch tensor.
# p=1.0 replaces the deprecated always_apply=True (same effect: always run).
test_transforms = A.Compose([
    A.Resize(height=1024, width=1024, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0),
])
# Select the inference device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 'pickel.pth' holds a whole pickled model object (not a state_dict), so it
# must be unpickled with weights_only=False. torch >= 2.6 defaults to True and
# would refuse it; the argument exists since torch 1.13.
# Unpickling can run arbitrary code: only ever load this trusted, bundled file.
model = torch.load('pickel.pth', map_location=torch.device('cpu'), weights_only=False)
model = model.to(device)
# Inference-only app: switch to evaluation mode once at load time.
model.eval()
def _rescale_boxes(boxes, from_hw, to_hw):
    """Scale ``(x1, y1, x2, y2)`` boxes from an image of size ``from_hw`` to one of size ``to_hw``."""
    (src_h, src_w), (dst_h, dst_w) = from_hw, to_hw
    sx, sy = dst_w / src_w, dst_h / src_h
    # new_tensor keeps the dtype/device of the predicted boxes.
    return boxes * boxes.new_tensor([sx, sy, sx, sy])


def predict(img) -> Tuple[np.ndarray, float]:
    """Detect palm trees in ``img``; return the annotated image and the latency.

    Args:
        img: Input ``PIL.Image.Image``, as supplied by the Gradio
            ``gr.Image(type="pil")`` input.

    Returns:
        ``(annotated, seconds)``: ``annotated`` is an RGB ``numpy`` array of
        the original-size image with the NMS-filtered boxes drawn on it;
        ``seconds`` is the wall-clock prediction time rounded to 5 decimals.
    """
    start_time = timer()

    # Force 3 channels: uploads may be RGBA, palette or greyscale images.
    rgb = np.array(img.convert('RGB'))
    orig_h, orig_w = rgb.shape[:2]

    # Albumentations pipelines take keyword targets and return a dict.
    image_transformed = test_transforms(image=rgb)['image']
    model_h, model_w = image_transformed.shape[-2:]
    batch = image_transformed.unsqueeze(0).to(device)

    # Inference
    model.eval()
    with torch.no_grad():
        predictions = model(batch)[0]
    nms_prediction = dict(apply_nms(predictions, iou_thresh=0.1))

    # The model saw the resized image; map boxes back to original pixels so
    # they line up with the image we draw on.
    nms_prediction['boxes'] = _rescale_boxes(
        nms_prediction['boxes'], (model_h, model_w), (orig_h, orig_w))

    # plot_img_bbox draws in place on a BGR array (cv2 convention). Use that
    # in-memory result rather than re-reading 'pred.jpg' from disk, which is
    # lossy and shared between concurrent requests.
    canvas = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    plot_img_bbox(canvas, nms_prediction)
    pred = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

    pred_time = round(timer() - start_time, 5)
    return pred, pred_time
### 4. Gradio app ###
# Create title, description and article strings
title = "🌴Palm trees detection🌴"
description = "Faster r-cnn model to detect oil palm trees in drones images."
article = "Created by data354."

# Build the examples list from the files in the "examples/" directory.
# Sorted for a stable order; hidden files (e.g. .gitkeep) are skipped; a
# missing directory yields no examples instead of crashing at startup.
EXAMPLES_DIR = "examples"
if os.path.isdir(EXAMPLES_DIR):
    example_list = [
        [os.path.join(EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if not name.startswith('.') and os.path.isfile(os.path.join(EXAMPLES_DIR, name))
    ]
else:
    example_list = []

# Create the Gradio demo. predict returns (annotated image, seconds), hence
# an image output followed by a number output.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Predictions"),
        gr.Number(label="Prediction time (s)"),
    ],
    examples=example_list or None,
    title=title,
    description=description,
    article=article,
)

# Launch the demo only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch(debug=False)