qubvel-hf's picture
qubvel-hf HF staff
Init project
c509e76
raw
history blame contribute delete
No virus
6.05 kB
import torch
import argparse
import numpy as np
import torch.nn.functional as F
import glob
import cv2
from tqdm import tqdm
import time
import os
from model.deep_lab_model.deeplab import *
from MBD import mask_base_dewarper
import time
from utils import cvimg2torch,torch2cvimg
def _replace_origin(img_path, tag):
    """Return img_path with the last '_origin' in its file name replaced by tag.

    Only the file name is edited, so a directory whose name contains
    '_origin' is left untouched. If the file name has no '_origin' the path
    is returned unchanged (callers then see the input file itself).
    """
    folder, name = os.path.split(img_path)
    head, marker, tail = name.rpartition('_origin')
    if not marker:
        return img_path
    return os.path.join(folder, head + tag + tail)


def _predict_binary_mask(seg_model, img_org, threshold=100):
    """Run seg_model on a BGR image and return a uint8 mask containing only 0 and 255.

    The image is resized to 448x448, Gaussian-blurred and converted to RGB
    before the forward pass. Channel 0 of the prediction is then upsampled
    back to the original height/width.
    """
    h_org, w_org = img_org.shape[:2]
    img = cv2.resize(img_org, (448, 448))
    img = cv2.GaussianBlur(img, (15, 15), 0, 0)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cvimg2torch(img)
    with torch.no_grad():
        pred = seg_model(img.cuda())
        mask_pred = F.interpolate(pred[:, 0, :, :].unsqueeze(1), (h_org, w_org))
        # assumes cvimg2torch yields a batch of one image — TODO confirm
        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
    # Clip before casting: scores outside [0, 1] would otherwise wrap around in uint8.
    mask_pred = np.clip(mask_pred * 255, 0, 255).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    # Dilate then erode (morphological closing) to close small gaps in the mask.
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    # Binarize every pixel; values equal to the threshold map to 0.
    return np.where(mask_pred > threshold, 255, 0).astype(np.uint8)


def _identity_dewarp(img_org):
    """Fallback when mask_base_dewarper fails: resample img_org with an identity grid.

    Returns (dewarped image, grid of shape (H, W, 2) in [-1, 1] coordinates).
    """
    h_org, w_org = img_org.shape[:2]
    grid = np.stack(np.meshgrid(np.arange(w_org), np.arange(h_org)))
    grid = grid / np.array([w_org, h_org]).reshape(2, 1, 1)
    # NOTE(review): with align_corners=False pixel centres are at (2*i+1)/w - 1,
    # so this grid is offset by half a pixel. Kept as-is because the saved grid
    # format must match what mask_base_dewarper produces — confirm.
    grid = torch.from_numpy((grid - 0.5) * 2).float().unsqueeze(0).permute(0, 2, 3, 1)
    dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org), grid, align_corners=False))[0]
    return dewarp, grid[0].numpy()


def net1_net2_infer(model, img_paths, args):
    """Segment, dewarp and save results for every '*_origin.*' image in img_paths.

    For each image three artefacts are written next to it, derived by
    replacing '_origin' in the file name:
      * '_capture'  - dewarped image (its existence makes the image skipped),
      * '_mask_new' - binary 0/255 segmentation mask,
      * '_grid1'    - 128x128x2 sampling grid saved with np.save (which
        appends '.npy' to the name).

    Args:
        model: segmentation network; switched to eval mode here and run on CUDA.
        img_paths: iterable of image file paths.
        args: parsed CLI arguments; not read here, kept for call compatibility.
    """
    seg_model = model
    seg_model.eval()
    for img_path in tqdm(img_paths):
        capture_path = _replace_origin(img_path, '_capture')
        if os.path.exists(capture_path):
            continue
        img_org = cv2.imread(img_path)
        if img_org is None:
            print('fail: cannot read', img_path)
            continue

        mask_pred = _predict_binary_mask(seg_model, img_org)

        # TPS transform based on the mask; fall back to an identity warp on failure.
        try:
            dewarp, grid = mask_base_dewarper(img_org, mask_pred)
        except Exception as exc:
            print('fail', img_path, exc)
            dewarp, grid = _identity_dewarp(img_org)

        if not cv2.imwrite(capture_path, dewarp):
            print('fail: cannot write', capture_path)
        cv2.imwrite(_replace_origin(img_path, '_mask_new'), mask_pred)
        grid_small = np.stack(
            (cv2.resize(grid[:, :, 0], (128, 128)), cv2.resize(grid[:, :, 1], (128, 128))),
            axis=-1,
        )
        np.save(_replace_origin(img_path, '_grid1'), grid_small)
def net1_net2_infer_single_im(img, model_path):
    """Predict a binary document mask for a single BGR image.

    Builds the DeepLab segmentation network, loads weights from model_path
    and runs it on img. The network is rebuilt and the checkpoint reloaded on
    every call.

    Args:
        img: BGR image array (as from cv2.imread); it is converted BGR->RGB
            before inference.
        model_path: checkpoint file whose 'model_state' entry holds the
            DataParallel-wrapped model's state dict.

    Returns:
        uint8 mask containing only 0 and 255, upsampled to img's height/width
        (assuming cvimg2torch yields a batch of one — TODO confirm).
    """
    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    # torch.load unpickles the file: only pass trusted checkpoints.
    checkpoint = torch.load(model_path)
    seg_model.load_state_dict(checkpoint['model_state'])
    seg_model.eval()

    # Segmentation mask prediction.
    h_org, w_org = img.shape[:2]
    net_in = cv2.resize(img, (448, 448))
    net_in = cv2.GaussianBlur(net_in, (15, 15), 0, 0)
    net_in = cv2.cvtColor(net_in, cv2.COLOR_BGR2RGB)
    net_in = cvimg2torch(net_in)
    with torch.no_grad():
        pred = seg_model(net_in.cuda())
        mask_pred = F.interpolate(pred[:, 0, :, :].unsqueeze(1), (h_org, w_org))
        mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
    # Clip before casting: scores outside [0, 1] would otherwise wrap around in uint8.
    mask_pred = np.clip(mask_pred * 255, 0, 255).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    # Dilate then erode (morphological closing) to close small gaps in the mask.
    mask_pred = cv2.dilate(mask_pred, kernel, iterations=3)
    mask_pred = cv2.erode(mask_pred, kernel, iterations=3)
    # Binarize every pixel; values equal to 100 map to 0 instead of staying 100.
    return np.where(mask_pred > 100, 255, 0).astype(np.uint8)
if __name__ == '__main__':
    # Options take exactly one value: with nargs='?' a bare '--img_folder'
    # silently became None and failed later inside os.path.join.
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--img_folder', type=str, default='./all_data',
                        help='Data path to load data')
    # NOTE: net1_net2_infer resizes to a fixed 448x448 and does not read these two flags.
    parser.add_argument('--img_rows', type=int, default=448,
                        help='Height of the input image')
    parser.add_argument('--img_cols', type=int, default=448,
                        help='Width of the input image')
    parser.add_argument('--seg_model_path', type=str, default='checkpoints/mbd.pkl',
                        help='Path to previous saved model to restart from')
    args = parser.parse_args()

    # Sorted for a reproducible processing order; checked before the
    # (expensive) model load so a wrong folder fails fast.
    im_paths = sorted(glob.glob(os.path.join(args.img_folder, '*_origin.*')))
    if not im_paths:
        raise SystemExit('no *_origin.* images found in %s' % args.img_folder)

    seg_model = DeepLab(num_classes=1,
                        backbone='resnet',
                        output_stride=16,
                        sync_bn=None,
                        freeze_bn=False)
    seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
    seg_model.cuda()
    # torch.load unpickles the file: only pass trusted checkpoints.
    checkpoint = torch.load(args.seg_model_path)
    seg_model.load_state_dict(checkpoint['model_state'])
    net1_net2_infer(seg_model, im_paths, args)