import os
from glob import glob

import cv2
import numpy as np
import torch

from model import U2NET


def detect_single_face(face_cascade, img):
    """Detect the largest frontal face in a BGR image.

    Parameters
    ----------
    face_cascade : cv2.CascadeClassifier
        Pre-loaded Haar-cascade frontal-face detector.
    img : np.ndarray
        BGR image as returned by ``cv2.imread``.

    Returns
    -------
    The ``(x, y, w, h)`` bounding box of the largest detected face, or
    ``None`` when no face is found (the caller then runs the portrait
    network on the whole image).
    """
    # Haar cascades operate on grayscale input.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    faces = face_cascade.detectMultiScale(gray, 1.1, 4)
    if len(faces) == 0:
        print("Warning: no face detected, the portrait u2net will run on the whole image!")
        return None

    # Keep only the face with the largest area (ties keep the first hit,
    # matching the original strict-inequality scan).
    return max(faces, key=lambda f: f[2] * f[3])


def crop_face(img, face):
    """Crop a white-padded, square, 512x512 region around a detected face.

    The bbox is enlarged (40% of the width on each side, 60% of the height
    above, 20% below) so the network sees hair/shoulder context.  Any part
    of the enlarged box falling outside the image is re-added as white
    padding, then the crop is padded to a square to avoid deformation when
    resizing.

    Parameters
    ----------
    img : np.ndarray
        BGR image.
    face : tuple or None
        ``(x, y, w, h)`` bbox; ``None`` means "use the whole image".

    Returns
    -------
    np.ndarray
        512x512x3 face crop, or ``img`` unchanged when ``face`` is None.
    """
    if face is None:
        # No detection: the caller runs the network on the full image.
        return img

    x, y, w, h = face
    height, width = img.shape[0:2]

    # l/r/t/b record how much of the enlarged bbox fell outside the image
    # and must be re-added below as white padding.
    l = r = t = b = 0

    lpad = int(float(w) * 0.4)
    left = x - lpad
    if left < 0:
        l = lpad - x
        left = 0

    rpad = int(float(w) * 0.4)
    right = x + w + rpad
    if right > width:
        r = right - width
        right = width

    tpad = int(float(h) * 0.6)
    top = y - tpad
    if top < 0:
        t = tpad - y
        top = 0

    bpad = int(float(h) * 0.2)
    bottom = y + h + bpad
    if bottom > height:
        b = bottom - height
        bottom = height

    im_face = img[top:bottom, left:right]
    if len(im_face.shape) == 2:
        # Grayscale input: replicate the single channel to 3 channels.
        # FIX: the original np.repeat(a, (1,1,3)) raises ValueError; the
        # correct call repeats 3 times along the new channel axis.
        im_face = np.repeat(im_face[:, :, np.newaxis], 3, axis=2)

    # Re-add the clipped part of the enlarged bbox as white padding.
    im_face = np.pad(im_face, ((t, b), (l, r), (0, 0)),
                     mode='constant', constant_values=255)

    # Pad to (near-)square to avoid face deformation after resizing.
    hf, wf = im_face.shape[0:2]
    if hf - 2 > wf:
        wfp = int((hf - wf) / 2)
        im_face = np.pad(im_face, ((0, 0), (wfp, wfp), (0, 0)),
                         mode='constant', constant_values=255)
    elif wf - 2 > hf:
        hfp = int((wf - hf) / 2)
        im_face = np.pad(im_face, ((hfp, hfp), (0, 0), (0, 0)),
                         mode='constant', constant_values=255)

    # Network input resolution.
    im_face = cv2.resize(im_face, (512, 512), interpolation=cv2.INTER_AREA)
    return im_face


def normPRED(d):
    """Min-max normalize a tensor to [0, 1].

    A constant input maps to all zeros instead of dividing by zero
    (the original would produce NaNs in that degenerate case).
    """
    ma = torch.max(d)
    mi = torch.min(d)
    if ma == mi:
        return torch.zeros_like(d)
    return (d - mi) / (ma - mi)


def inference(net, input):
    """Run U2NET portrait inference on one BGR image.

    Parameters
    ----------
    net : U2NET
        Model already moved to the right device and set to eval mode.
    input : np.ndarray
        HxWx3 BGR image; values are rescaled by the image maximum.

    Returns
    -------
    np.ndarray
        HxW float map in [0, 1] (inverted, min-max-normalized d1 output).
    """
    # Scale to [0,1] then apply per-channel normalization.
    # NOTE(review): the mean/std-to-channel pairing below (R uses
    # 0.406/0.225, B uses 0.485/0.229) looks swapped vs. the standard
    # ImageNet ordering, but the released checkpoint expects exactly this
    # preprocessing, so it is deliberately kept unchanged.
    tmpImg = np.zeros((input.shape[0], input.shape[1], 3))
    scaled = input / np.max(input)
    tmpImg[:, :, 0] = (scaled[:, :, 2] - 0.406) / 0.225
    tmpImg[:, :, 1] = (scaled[:, :, 1] - 0.456) / 0.224
    tmpImg[:, :, 2] = (scaled[:, :, 0] - 0.485) / 0.229

    # HWC -> NCHW float tensor with a batch dimension of 1.
    tmpImg = tmpImg.transpose((2, 0, 1))[np.newaxis, :, :, :]
    tmpImg = torch.from_numpy(tmpImg).type(torch.FloatTensor)
    if torch.cuda.is_available():
        tmpImg = tmpImg.cuda()

    # torch.autograd.Variable is deprecated; no_grad() skips building the
    # autograd graph, cutting memory use during inference.
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(tmpImg)

    # The portrait model predicts dark strokes; invert then normalize.
    pred = normPRED(1.0 - d1[:, 0, :, :])
    pred = pred.squeeze().cpu().data.numpy()

    del d1, d2, d3, d4, d5, d6, d7
    return pred


def main():
    """Batch-run the portrait network over every image in the input folder."""
    im_list = glob('./test_data/test_portrait_images/your_portrait_im/*')
    print("Number of images: ", len(im_list))

    out_dir = './test_data/test_portrait_images/your_portrait_results'
    # makedirs(exist_ok=True) also creates missing parents and is safe to
    # call when the directory already exists (os.mkdir is neither).
    os.makedirs(out_dir, exist_ok=True)

    # Haar-cascade face detector used to locate the crop region.
    face_cascade = cv2.CascadeClassifier(
        './saved_models/face_detection_cv2/haarcascade_frontalface_default.xml')

    model_dir = './saved_models/u2net_portrait/u2net_portrait.pth'
    net = U2NET(3, 1)
    # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
    # the model is moved to CUDA afterwards when available.
    net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()

    # Do the inference one-by-one.
    for i, im_path in enumerate(im_list):
        print("--------------------------")
        print("inferencing ", i, "/", len(im_list), im_path)

        img = cv2.imread(im_path)
        if img is None:
            # Unreadable or non-image file: skip instead of crashing the batch.
            print("Warning: could not read", im_path, "- skipped")
            continue

        face = detect_single_face(face_cascade, img)
        im_face = crop_face(img, face)
        im_portrait = inference(net, im_face)

        # Portable output path (the original split on '/' breaks on Windows).
        name = os.path.splitext(os.path.basename(im_path))[0]
        cv2.imwrite(os.path.join(out_dir, name + '.png'),
                    (im_portrait * 255).astype(np.uint8))


if __name__ == '__main__':
    main()