youl committed
Commit 4655008
1 Parent(s): 2b0be53

application

Files changed (2):
  1. app.py        +101 -0
  2. functions.py  +183 -0
app.py ADDED
@@ -0,0 +1,101 @@
+ import gradio as gr
+ import albumentations as A
+ # the star import also brings in torch, numpy (np), os, gc, shutil, warnings,
+ # sleep, timer and ToTensorV2, all of which are used below
+ from functions import *
+ warnings.filterwarnings('ignore')
+
+
+ # resize to the model's 1024x1024 input, normalize, and convert to a tensor
+ test_transforms = A.Compose([
+     A.Resize(height=1024, width=1024, always_apply=True),
+     A.Normalize(always_apply=True),
+     ToTensorV2(always_apply=True),
+ ])
+
+ # select device (whether GPU or CPU)
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ # model loading
+ model = torch.load('pickel.pth', map_location=torch.device('cpu'))
+ model = model.to(device)
+
+ # returns the annotated image and a label string  -> Tuple[np.ndarray, str]
+ def predict(img):
+
+     # Start a timer
+     start_time = timer()
+     image = np.array(img)
+     h, w, _ = image.shape
+     hw = h * w
+
+     # images under ~2 megapixels get a single forward pass;
+     # larger ones are tiled into 1024x1024 crops first
+     if hw < 2 * 1024 * 1024:
+
+         # Transform the target image and add a batch dimension
+         transformed = test_transforms(image=image)
+         image_transformed = transformed["image"]
+         image_transformed = image_transformed.unsqueeze(0)
+         image_transformed = image_transformed.to(device)
+
+         # inference
+         model.eval()
+         with torch.no_grad():
+             predictions = model(image_transformed)[0]
+
+         # drop overlapping detections, then draw the surviving boxes
+         nms_prediction = apply_nms(predictions, iou_thresh=0.1)
+         pred = plot_img_bbox(image, nms_prediction)
+         word = "Number of palm trees detected : " + str(len(nms_prediction["boxes"]))
+
+         # Calculate the prediction time
+         pred_time = round(timer() - start_time, 5)
+
+         # Return the annotated image and the detection count
+         return pred, word
+
+     else:
+         # tile the image, run per-tile inference, then stitch the
+         # annotated tiles back together
+         crop(image)
+         locations = np.load("locations.npy")
+         n = inference(image, locations, model, test_transforms, device)
+         empty_image = np.zeros(image.shape)
+         del image
+         gc.collect()
+         sleep(1)
+
+         word = "Number of palm trees detected : " + str(n)
+         pred = create_new_ortho(locations, empty_image)
+
+         # remove temporary files and folders
+         os.remove("locations.npy")
+         shutil.rmtree("images", ignore_errors=True)
+         shutil.rmtree("labels", ignore_errors=True)
+
+         return pred, word
+
+
+ image = gr.components.Image()
+ out_im = gr.components.Image()
+ out_lab = gr.components.Label()
+
+ ### 4. Gradio app ###
+ # Create title, description and article strings
+ title = "🌴Palm trees detection🌴"
+ description = "Faster R-CNN model to detect oil palm trees in drone images."
+ article = "Created by data354."
+
+ # Create the examples list from the "examples/" directory
+ example_list = [["examples/" + example] for example in os.listdir("examples")]
+
+ # Create the Gradio demo
+ demo = gr.Interface(fn=predict,                 # mapping function from input to output
+                     inputs=image,               # input image widget
+                     outputs=[out_im, out_lab],  # two outputs: annotated image and label
+                     examples=example_list,
+                     title=title,
+                     description=description,
+                     article=article)
+
+ # Launch the demo!
+ demo.launch(debug=False)
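
Note: the hw < 2*1024*1024 check routes images under roughly two megapixels through the single forward pass; anything larger (a 4000x3000 drone ortho is 12 MP) takes the tiled branch. A minimal sketch of exercising predict directly once the module above has loaded the model (the example path is hypothetical, not part of this commit):

    from PIL import Image

    img = Image.open("examples/sample.jpg")  # hypothetical sample image
    annotated, label = predict(img)          # predict() converts to np.array itself
    print(label)                             # "Number of palm trees detected : <n>"
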
functions.py ADDED
@@ -0,0 +1,183 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from torchvision.ops import box_iou
+ import cv2
+ import os
+ import numpy as np
+ from PIL import Image
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ from tqdm import tqdm  # create_new_ortho calls tqdm(...) directly
+ import gc
+ from time import sleep
+ import shutil
+ from timeit import default_timer as timer
+ from typing import Tuple, Dict
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ # apply the NMS algorithm to a single prediction dict
+ def apply_nms(orig_prediction, iou_thresh=0.3):
+     # torchvision returns the indices of the bboxes to keep
+     keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
+
+     final_prediction = orig_prediction
+     final_prediction['boxes'] = final_prediction['boxes'][keep]
+     final_prediction['scores'] = final_prediction['scores'][keep]
+     final_prediction['labels'] = final_prediction['labels'][keep]
+
+     return final_prediction
+
+
+ # same as apply_nms, but for a list of prediction dicts
+ def apply_nms2(orig_prediction, iou_thresh=0.3):
+     preds = []
+     for prediction in orig_prediction:
+         # torchvision returns the indices of the bboxes to keep
+         keep = torchvision.ops.nms(prediction['boxes'], prediction['scores'], iou_thresh)
+
+         final_prediction = prediction
+         final_prediction['boxes'] = final_prediction['boxes'][keep]
+         final_prediction['scores'] = final_prediction['scores'][keep]
+         final_prediction['labels'] = final_prediction['labels'][keep]
+         preds.append(final_prediction)
+
+     return preds
+
+ # Draw the bounding boxes on the full-resolution image
+ def plot_img_bbox(img, target):
+     h, w, c = img.shape
+     for box in target['boxes']:
+         # boxes were predicted on the 1024x1024 resized image,
+         # so rescale them to the original image size
+         xmin = int((box[0].cpu() / 1024) * w)
+         ymin = int((box[1].cpu() / 1024) * h)
+         xmax = int((box[2].cpu() / 1024) * w)
+         ymax = int((box[3].cpu() / 1024) * h)
+         cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
+
+         # Add the class label above the box
+         label = "palm"
+         cv2.putText(img, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
+
+     return img
+
+ # Tile a large image into size x size crops saved under "images/";
+ # each crop's [row_start, row_end, col_start, col_end] goes to locations.npy
+ def crop(image, size=1024):
+     img = image.copy()
+     H, W, _ = img.shape
+     h = H // size
+     w = W // size
+     H1 = h * size
+     W1 = w * size
+     os.makedirs("images", exist_ok=True)
+     locations = []
+
+     # if the height is not a multiple of `size`, add one extra row of crops
+     # aligned to the bottom edge; it overlaps ("chevauche") the previous row.
+     # e.g. H=2500, size=1024: H1=2048, val_h = 2048 - (1024 - 452) = 1476,
+     # so row starts are [0, 1024, 1476] and the last crop ends exactly at 2500
+     if H1 < H:
+         chevauche_h = H - H1
+         rest_h = size - chevauche_h
+         val_h = H1 - rest_h
+         H2 = [x for x in range(0, H1, size)] + [val_h]
+     else:
+         H2 = [x for x in range(0, H1, size)]
+
+     # same handling for the width
+     if W1 < W:
+         chevauche_w = W - W1
+         rest_w = size - chevauche_w
+         val_w = W1 - rest_w
+         W2 = [x for x in range(0, W1, size)] + [val_w]
+     else:
+         W2 = [x for x in range(0, W1, size)]
+
+     for i in H2:
+         for j in W2:
+             crop_img = img[i:i + size, j:j + size, :]
+             name = "img_" + str(i) + "_" + str(j) + ".png"
+             locations.append([i, i + size, j, j + size])
+             cv2.imwrite(os.path.join("images", name), crop_img)
+             del crop_img
+             gc.collect()
+
+     del H, H1, H2, W, W1, W2, h, w
+     gc.collect()
+     sleep(1)
+     np.save("locations.npy", np.array(locations))
+
+ # Run per-tile inference, draw the detections on each tile, save the
+ # annotated tiles under "labels/", and return the total detection count
+ def inference(image, locations, model, test_transforms, device):
+     n = 0
+     os.makedirs("labels", exist_ok=True)
+     for i, location in enumerate(locations):
+         name = "img_" + str(location[0]) + "_" + str(location[2]) + ".png"
+         path = os.path.join("images", name)
+         imgs = np.array(cv2.imread(path))
+         transformed = test_transforms(image=imgs)
+         image_transformed = transformed["image"]
+         image_transformed = image_transformed.unsqueeze(0)
+         image_transformed = image_transformed.to(device)
+
+         model.eval()
+         with torch.no_grad():
+             predictions = model(image_transformed)
+
+         del imgs, name, path, transformed, image_transformed
+         gc.collect()
+         sleep(1)
+
+         nms_prediction = apply_nms2(predictions, iou_thresh=0.1)
+         img = image[location[0]:location[1], location[2]:location[3], :]
+         n = n + len(nms_prediction[0]['boxes'])
+
+         for box in nms_prediction[0]['boxes']:
+             # tiles are already 1024x1024, so the coordinates map directly
+             xmin, ymin, xmax, ymax = int(box[0].cpu()), int(box[1].cpu()), int(box[2].cpu()), int(box[3].cpu())
+             cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
+
+             # Add the class label above the box
+             label = "palm"
+             cv2.putText(img, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
+             del label
+
+         label_name = "lab_" + str(location[0]) + "_" + str(location[2]) + ".png"
+         cv2.imwrite(os.path.join("labels", label_name), img)
+
+         del label_name, img, nms_prediction, predictions
+         gc.collect()
+         sleep(1)
+
+     return n
+
+ # Stitch the annotated tiles from "labels/" back into a full-size image.
+ # Every 300 tiles the mosaic is flushed to disk and reloaded, trading I/O
+ # for lower peak memory on very large orthos.
+ def create_new_ortho(locations, empty_image):
+     for i, location in tqdm(enumerate(locations), total=len(locations)):
+         name = "lab_" + str(location[0]) + "_" + str(location[2]) + ".png"
+         path = os.path.join("labels", name)
+         img = np.array(cv2.imread(path))
+         empty_image[location[0]:location[1], location[2]:location[3], :] = img
+         if i % 300 == 0:
+             cv2.imwrite("img.png", empty_image)
+             del img, name, path, empty_image
+             gc.collect()
+             empty_image = np.array(cv2.imread("img.png"))
+
+     cv2.imwrite("img.png", empty_image)
+     empty_image = np.array(cv2.imread("img.png"))
+     return empty_image
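
Note: both apply_nms helpers delegate to torchvision.ops.nms, which keeps the highest-scoring box in each overlapping cluster and suppresses the rest. A small self-contained sketch of the behaviour at the iou_thresh=0.1 the app uses (the boxes and scores below are made up for illustration):

    import torch
    import torchvision

    boxes = torch.tensor([[ 0.,  0., 10., 10.],   # overlaps the next box (IoU ~0.68)
                          [ 1.,  1., 11., 11.],
                          [50., 50., 60., 60.]])  # isolated box
    scores = torch.tensor([0.9, 0.8, 0.7])

    keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.1)
    print(keep)  # tensor([0, 2]) -- the lower-scoring overlapping box is dropped

A threshold of 0.1 is aggressive: almost any overlap between two detections suppresses the lower-scoring one, which favours removing duplicate boxes over keeping genuinely adjacent trees.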