File size: 3,982 Bytes
a0abb52
 
 
 
 
 
b0035fe
a0abb52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84d9a49
a0abb52
84d9a49
a0abb52
 
 
 
 
 
 
b5bf50b
 
 
a0abb52
 
 
 
 
 
 
 
 
 
 
 
 
 
3b9da9b
 
a0abb52
 
 
 
 
717fe15
 
bbff62d
a0abb52
 
 
 
 
 
 
 
 
 
b5bf50b
a0abb52
b5bf50b
a98ab4d
a0abb52
 
 
 
 
a98ab4d
b5bf50b
 
 
a98ab4d
a0abb52
 
 
 
 
 
 
 
 
b5bf50b
 
 
a0abb52
 
081def8
21b130c
a0abb52
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import inspect
import os
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import torch
import cv2
import torch.nn as nn
import numpy as np
import torchvision
from torchvision.ops import box_iou
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Apply the non-maximum-suppression (NMS) algorithm to a detection dict.
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Drop overlapping detections using non-maximum suppression.

    Args:
        orig_prediction: detection dict whose ``'boxes'``, ``'scores'`` and
            ``'labels'`` entries are tensors with one row/element per
            detection. Boxes must be in ``(x1, y1, x2, y2)`` format, as
            required by ``torchvision.ops.nms``.
        iou_thresh: boxes whose IoU with a higher-scoring kept box exceeds
            this value are discarded.

    Returns:
        A new dict with the same keys as ``orig_prediction``, where
        ``'boxes'``, ``'scores'`` and ``'labels'`` contain only the surviving
        detections. ``orig_prediction`` itself is left unmodified.
    """
    # torchvision returns the indices of the boxes to keep, sorted by decreasing score.
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)

    # Shallow copy: assigning into the caller's dict would silently mutate their predictions.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]

    return final_prediction

# Draw the bounding box
def plot_img_bbox(img, target):
        h,w,c = img.shape
        for box in (target['boxes']):
            xmin, ymin, xmax, ymax  = int((box[0].cpu()/1024)*w), int((box[1].cpu()/1024)*h), int((box[2].cpu()/1024)*w),int((box[3].cpu()/1024)*h)
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
            label = "palm"
            # Add the label and confidence score
            label = f'{label}'
            cv2.putText(img, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)

        # Display the image with detections
        #filename = 'pred.jpg'
        #cv2.imwrite(filename, img)
        return img

# Side length (pixels) of the square image the detector receives.
INPUT_SIZE = 1024

# Preprocessing applied before inference: resize to the model input size,
# normalize with albumentations' default mean/std, and convert HWC numpy -> CHW tensor.
# p=1.0 is the non-deprecated spelling of always_apply=True (the default for these transforms).
test_transforms = A.Compose([
    A.Resize(height=INPUT_SIZE, width=INPUT_SIZE, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0),
])

# Select the device (GPU if available, otherwise CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model loading.
MODEL_PATH = 'pickel.pth'
# The checkpoint is used directly as a model (not a state_dict), so it must be
# fully unpickled. torch >= 2.6 defaults to weights_only=True, which refuses to
# do that; opt out explicitly where the argument exists (added in torch 1.13).
# SECURITY: weights_only=False runs arbitrary pickle code - only load a trusted local file.
_load_kwargs = {'map_location': torch.device('cpu')}
if 'weights_only' in inspect.signature(torch.load).parameters:
    _load_kwargs['weights_only'] = False
model = torch.load(MODEL_PATH, **_load_kwargs)
model = model.to(device)
# Inference-only app: put dropout/batch-norm layers in eval mode once at load time.
model.eval()

# Gradio callback: returns (annotated image, summary text) for the two outputs.
def predict(img):
    """Detect palm trees in ``img``; return an annotated copy and a count message.

    Args:
        img: input image as delivered by the Gradio ``Image`` component
            (anything ``np.array`` converts to an image array), or ``None``
            when the user submits without uploading an image.

    Returns:
        ``(annotated_image, message)``: the image with detected boxes drawn and
        a text such as ``"Number of palm trees detected : 3"``. When ``img``
        is ``None``, returns ``(None, "No image provided.")``.
    """
    if img is None:
        # Gradio passes None for an empty input; the transform would crash on it.
        return None, "No image provided."

    # Convert once. The array feeds the model and is then drawn on; it is a copy,
    # so the caller's input is never modified.
    img_array = np.array(img)

    # Resize/normalize/tensorize, then add a batch dimension.
    batch = test_transforms(image=img_array)["image"].unsqueeze(0).to(device)

    # Inference. eval() is cheap and guarantees inference mode on every call.
    model.eval()
    with torch.no_grad():
        predictions = model(batch)[0]

    # iou_thresh=0.1 suppresses any box overlapping a better one by more than 10% IoU
    # (presumably because distinct palm crowns rarely overlap much - confirm).
    nms_prediction = apply_nms(predictions, iou_thresh=0.1)

    pred = plot_img_bbox(img_array, nms_prediction)
    word = f"Number of palm trees detected : {len(nms_prediction['boxes'])}"

    return pred, word
    
# Gradio components: one image input; an annotated image and a text label as outputs.
image = gr.components.Image()
out_im = gr.components.Image()
out_lab = gr.components.Label()

### 4. Gradio app ###
# Title, description and footer text shown on the app page.
title = "🌴Palm trees detection🌴"
description = "Faster r-cnn model to detect oil palm trees in drones images."
article = "Created by data354."

# Build the examples list from the "examples/" directory.
# - sorted(): os.listdir order is arbitrary, so sorting keeps the gallery order stable.
# - The extension filter skips hidden/non-image files such as .DS_Store.
# - A missing directory yields no examples instead of crashing at startup.
EXAMPLES_DIR = "examples"
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tif", ".tiff", ".webp")
if os.path.isdir(EXAMPLES_DIR):
    example_list = [
        [os.path.join(EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]
else:
    example_list = []

# Create the Gradio demo: one image in, (annotated image, count text) out.
demo = gr.Interface(fn=predict,
                    inputs=image,
                    outputs=[out_im, out_lab],
                    examples=example_list or None,
                    title=title,
                    description=description,
                    article=article,
                   )
# Launch the demo!
demo.launch(debug=False)