Palm_Detection_V3 / functions.py
youl's picture
functions update
becf64c
raw
history blame contribute delete
No virus
5.76 kB
import torch
import cv2
import os
import torch.nn as nn
import numpy as np
import torchvision
from torchvision.ops import box_iou
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import tqdm
import gc
from time import sleep
import shutil
from timeit import default_timer as timer
from typing import Tuple, Dict
import warnings
warnings.filterwarnings('ignore')
# apply nms algorithm
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Filter a single detection result with non-maximum suppression.

    Args:
        orig_prediction: Mapping with ``'boxes'``, ``'scores'`` and
            ``'labels'`` entries. Boxes and scores are handed to
            ``torchvision.ops.nms``; all three entries are indexed with the
            index tensor it returns.
        iou_thresh: IoU threshold forwarded to ``torchvision.ops.nms``.

    Returns:
        A new dict with the same entries as ``orig_prediction`` in which
        ``'boxes'``, ``'scores'`` and ``'labels'`` only hold the kept
        detections. The input mapping itself is not modified.
    """
    # torchvision returns the indices of the bboxes to keep
    keep = torchvision.ops.nms(
        orig_prediction['boxes'], orig_prediction['scores'], iou_thresh
    )
    # Shallow copy: rebinding keys on the copy leaves the caller's dict intact.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
def apply_nms2(orig_prediction, iou_thresh=0.3):
    """Apply non-maximum suppression to every prediction of a batch.

    Args:
        orig_prediction: Iterable of mappings, each with ``'boxes'``,
            ``'scores'`` and ``'labels'`` entries. Boxes and scores are
            handed to ``torchvision.ops.nms``; all three entries are indexed
            with the index tensor it returns.
        iou_thresh: IoU threshold forwarded to ``torchvision.ops.nms``.

    Returns:
        A list with one new dict per input prediction, in input order, whose
        ``'boxes'``, ``'scores'`` and ``'labels'`` only hold the kept
        detections. The input mappings are not modified.
    """
    preds = []
    for prediction in orig_prediction:
        # torchvision returns the indices of the bboxes to keep
        keep = torchvision.ops.nms(
            prediction['boxes'], prediction['scores'], iou_thresh
        )
        # Shallow copy so the caller's prediction is not overwritten in place.
        filtered = dict(prediction)
        for key in ('boxes', 'scores', 'labels'):
            filtered[key] = prediction[key][keep]
        preds.append(filtered)
    return preds
# Draw the bounding box
def plot_img_bbox(img, target, box_scale=1024):
    """Draw every box of ``target['boxes']`` on ``img`` with a "palm" caption.

    Each box coordinate is divided by ``box_scale`` and multiplied by the
    image width (x values) or height (y values) before drawing, i.e. boxes
    expressed in a ``box_scale``-sized frame are mapped onto the actual image
    size.

    Args:
        img: Array whose ``shape`` is ``(height, width, channels)``. It is
            passed directly to ``cv2.rectangle`` / ``cv2.putText``.
        target: Mapping whose ``'boxes'`` entry iterates over boxes indexable
            as ``(xmin, ymin, xmax, ymax)``; each coordinate must support
            ``.cpu()`` (e.g. a torch tensor).
        box_scale: Side length of the frame the box coordinates refer to.
            Defaults to 1024, the value that used to be hard-coded.

    Returns:
        The ``img`` object that was passed in.
    """
    height, width, _ = img.shape
    label = "palm"
    for box in target['boxes']:
        xmin = int((box[0].cpu() / box_scale) * width)
        ymin = int((box[1].cpu() / box_scale) * height)
        xmax = int((box[2].cpu() / box_scale) * width)
        ymax = int((box[3].cpu() / box_scale) * height)
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
        # Caption goes slightly above the top-left corner of the box.
        cv2.putText(img, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
    return img
def _tile_starts(length, size):
    """Return the start offsets of ``size``-long tiles covering ``length``."""
    covered = (length // size) * size
    starts = list(range(0, covered, size))
    if covered < length:
        # Leftover strip: add one more tile that ends exactly at ``length``
        # and therefore overlaps the previous one. Clamp to 0 for inputs
        # shorter than a single tile.
        starts.append(max(0, length - size))
    return starts


def crop(image, size=1024):
    """Cut ``image`` into tiles and write them, plus their positions, to disk.

    The image is split into ``size`` x ``size`` tiles. When a dimension is not
    a multiple of ``size``, an extra tile aligned with the far edge is added,
    overlapping its neighbour. Each tile is written to
    ``images/img_<row>_<col>.png`` and the list of
    ``[row_start, row_end, col_start, col_end]`` entries is saved to
    ``locations.npy`` in the current working directory.

    Args:
        image: Array whose ``shape`` is ``(height, width, channels)``.
        size: Tile side length in pixels. The overlap of the trailing tile is
            computed from this value (it used to assume 1024 regardless).

    Returns:
        None. The results are the files written to disk.
    """
    height, width, _ = image.shape
    os.makedirs("images", exist_ok=True)
    locations = []
    row_starts = _tile_starts(height, size)
    col_starts = _tile_starts(width, size)
    for i in row_starts:
        # Clamp the end so recorded locations never exceed the image bounds.
        i_end = min(i + size, height)
        for j in col_starts:
            j_end = min(j + size, width)
            name = "img_" + str(i) + "_" + str(j) + ".png"
            locations.append([i, i_end, j, j_end])
            cv2.imwrite(os.path.join("images", name), image[i:i_end, j:j_end, :])
    gc.collect()
    np.save("locations.npy", np.array(locations))
def inference(image, locations, model, test_transforms, device):
    """Run the detector on every saved tile and write annotated tiles.

    For each location the tile ``images/img_<row>_<col>.png`` is read,
    transformed, and passed through ``model``. The predictions are filtered
    with ``apply_nms2`` (IoU threshold 0.1), the kept boxes are drawn on the
    matching region of ``image``, and that region is written to
    ``labels/lab_<row>_<col>.png``.

    Args:
        image: Full image array; ``image[r0:r1, c0:c1, :]`` is taken for each
            location. That slice is a NumPy view, so the drawing is expected
            to show up in ``image`` as well.
        locations: Iterable of ``[row_start, row_end, col_start, col_end]``.
        model: Detection model; called on a batch of one tile and expected to
            return a sequence of dicts with ``'boxes'``, ``'scores'`` and
            ``'labels'``.
        test_transforms: Callable used as ``test_transforms(image=tile)``
            returning a mapping whose ``"image"`` entry is a tensor.
        device: Device the tile tensor is moved to before the forward pass.

    Returns:
        Total number of boxes kept after NMS over all tiles.

    Raises:
        FileNotFoundError: If a tile image cannot be read from ``images/``.
    """
    n = 0
    os.makedirs("labels", exist_ok=True)
    # Evaluation mode only needs to be set once, not once per tile.
    model.eval()
    for location in locations:
        name = "img_" + str(location[0]) + "_" + str(location[2]) + ".png"
        path = os.path.join("images", name)
        tile = cv2.imread(path)
        if tile is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read tile image: {path}")
        transformed = test_transforms(image=tile)
        batch = transformed["image"].unsqueeze(0).to(device)
        with torch.no_grad():
            predictions = model(batch)
        # Release the per-tile inputs before post-processing.
        del tile, transformed, batch
        gc.collect()

        nms_prediction = apply_nms2(predictions, iou_thresh=0.1)
        boxes = nms_prediction[0]['boxes']
        n += len(boxes)

        img = image[location[0]:location[1], location[2]:location[3], :]
        label = "palm"
        for box in boxes:
            xmin, ymin, xmax, ymax = int(box[0].cpu()), int(box[1].cpu()), int(box[2].cpu()), int(box[3].cpu())
            cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
            cv2.putText(img, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)

        label_name = "lab_" + str(location[0]) + "_" + str(location[2]) + ".png"
        cv2.imwrite(os.path.join("labels", label_name), img)
        del img, boxes, nms_prediction, predictions
        gc.collect()
    return n
def create_new_ortho(locations, empty_image):
    """Assemble the annotated tiles from ``labels/`` into one image.

    Each ``labels/lab_<row>_<col>.png`` is pasted into
    ``empty_image[r0:r1, c0:c1, :]``. Every 300 tiles (starting with the
    first) the canvas is written to ``img.png`` and read back, after which
    pasting continues on the re-read array; from that point on the array
    object passed in by the caller is no longer updated. The final result is
    also written to ``img.png`` and read back before being returned.

    Args:
        locations: Sized iterable of
            ``[row_start, row_end, col_start, col_end]``.
        empty_image: Canvas array the tiles are pasted into.

    Returns:
        The assembled image as read back from ``img.png``.

    Raises:
        FileNotFoundError: If a label tile cannot be read from ``labels/``.
    """
    # The file does ``import tqdm`` (the module), so the progress bar class
    # has to be reached as ``tqdm.tqdm``; calling the module raises TypeError.
    for i, location in tqdm.tqdm(enumerate(locations), total=len(locations)):
        name = "lab_" + str(location[0]) + "_" + str(location[2]) + ".png"
        path = os.path.join("labels", name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read label tile: {path}")
        empty_image[location[0]:location[1], location[2]:location[3], :] = img
        if i % 300 == 0:
            # Periodic checkpoint: persist the canvas, drop the in-memory
            # arrays, then continue on the copy read back from disk.
            cv2.imwrite("img.png", empty_image)
            del img, name, path, empty_image
            gc.collect()
            empty_image = np.array(cv2.imread("img.png"))
    cv2.imwrite("img.png", empty_image)
    empty_image = np.array(cv2.imread("img.png"))
    return empty_image