Realcat
add: COTR(https://github.com/ubc-vision/COTR)
10dcc2e
raw history blame
No virus
3.05 kB
'''
COTR demo for homography estimation
'''
import argparse
import os
import sys
import time

import cv2
import numpy as np
import torch
import imageio
import matplotlib.pyplot as plt

from COTR.utils import utils, debug_utils
from COTR.models import build_model
from COTR.options.options import *
from COTR.options.options_utils import *
from COTR.inference.inference_helper import triangulate_corr
from COTR.inference.sparse_engine import SparseEngine
# Import-time setup for this demo script.
# COTR helper called with seed 0; presumably seeds the RNGs so runs are repeatable.
utils.fix_randomness(0)
# The demo only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(mode=False)
# Sample inputs used by the demo (paths are relative to the working directory).
_ANNOTATED_FRAME_PATH = './sample_data/imgs/paint_1.JPG'
_TARGET_FRAME_PATH = './sample_data/imgs/paint_2.jpg'
_REPLACEMENT_PATH = './sample_data/imgs/Meisje_met_de_parel.jpg'
# Hand-annotated (x, y) pixel corners of the region in the annotated frame, ordered
# left-upper, right-upper, left-bottom, right-bottom. This order must match the
# source-corner order used in _paste_into_frame.
_ANNOTATED_CORNERS = np.array(
    [[932, 1025], [2469, 901], [908, 2927], [2436, 3080]], dtype=np.float32)


def _load_model(opt):
    """Build the COTR model on the GPU, load ``opt.load_weights_path`` and set eval mode."""
    weights_path = getattr(opt, 'load_weights_path', None)
    if weights_path is None:
        # Fail before allocating the model rather than with an AttributeError later.
        raise ValueError('opt.load_weights_path is not set; pass --load_weights '
                         '<model id> so the checkpoint can be located')
    model = build_model(opt).cuda()
    # NOTE: torch.load unpickles the checkpoint file -- only load trusted checkpoints.
    weights = torch.load(weights_path, map_location='cpu')['model_state_dict']
    utils.safe_load_weights(model, weights)
    return model.eval()


def _paste_into_frame(rep_img, frame, dst_corners):
    """Warp ``rep_img`` so its corners land on ``dst_corners`` in ``frame`` and composite.

    ``dst_corners`` holds four (x, y) points ordered left-upper, right-upper,
    left-bottom, right-bottom. Pixels covered by the warped image take its
    values; every other pixel keeps the value from ``frame``.
    """
    h, w = rep_img.shape[:2]
    src_corners = np.array([[0, 0], [w, 0], [0, h], [w, h]], dtype=np.float32)
    T = cv2.getPerspectiveTransform(src_corners, dst_corners.astype(np.float32))
    size = (frame.shape[1], frame.shape[0])  # cv2 expects (width, height)
    # Warp an all-ones mask with the same transform to find the covered pixels.
    vmask = cv2.warpPerspective(np.ones((h, w)), T, size) > 0
    warped = cv2.warpPerspective(rep_img, T, size)
    return np.where(vmask[..., None], warped, frame)


def _show_results(panels):
    """Show ``(image, title)`` pairs side by side, without axes, and block on the window."""
    _, axes = plt.subplots(1, len(panels), squeeze=False)
    for ax, (image, title) in zip(axes[0], panels):
        ax.imshow(image)
        ax.title.set_text(title)
        ax.axis('off')
    plt.show()


def main(opt):
    """Run the COTR homography demo: overlay a replacement image onto a target frame.

    The four annotated corners in the annotated frame are matched into the
    target frame with COTR. The four resulting correspondences define a
    perspective transform that warps the replacement image into the target
    frame. The inputs and the composite are then displayed with matplotlib.

    Args:
        opt: parsed COTR options. They are passed to ``build_model`` and must
            carry ``load_weights_path`` (the checkpoint to load).

    Raises:
        ValueError: if ``opt.load_weights_path`` is missing or None.
        RuntimeError: if COTR does not return exactly one correspondence per query.
    """
    model = _load_model(opt)
    img_a = imageio.imread(_ANNOTATED_FRAME_PATH, pilmode='RGB')
    img_b = imageio.imread(_TARGET_FRAME_PATH, pilmode='RGB')
    rep_img = imageio.imread(_REPLACEMENT_PATH, pilmode='RGB')

    engine = SparseEngine(model, 32, mode='stretching')
    # Pass a copy so the module-level constant cannot be modified by the engine.
    queries = _ANNOTATED_CORNERS.copy()
    corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1,
                                        queries_a=queries, force=True)
    # A perspective transform needs exactly four point pairs, in query order.
    if corrs.shape[0] != len(queries):
        raise RuntimeError('expected {} correspondences, got {}'.format(
            len(queries), corrs.shape[0]))
    # Columns 2: are used as the target-frame (x, y) of each query
    # (presumably rows are [x_a, y_a, x_b, y_b] -- confirm in SparseEngine).
    out = _paste_into_frame(rep_img, img_b, corrs[:, 2:])

    _show_results([
        (rep_img, 'Virtual Paint'),
        (img_a, 'Annotated Frame'),
        (img_b, 'Target Frame'),
        (out, 'Overlay'),
    ])
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    opt = parser.parse_args()
    opt.command = ' '.join(sys.argv)

    # dim_feedforward is derived from the channel count associated with --layer.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    if opt.layer not in layer_2_channels:
        parser.error('unsupported --layer {!r}; expected one of {}'.format(
            opt.layer, sorted(layer_2_channels)))
    opt.dim_feedforward = layer_2_channels[opt.layer]

    # main() always loads a checkpoint, so report a missing model id up front
    # instead of failing inside main() after the model has been built.
    if not opt.load_weights:
        parser.error('--load_weights is required (checkpoint is read from '
                     '<out_dir>/<load_weights>/checkpoint.pth.tar)')
    opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    print_opt(opt)
    main(opt)