# Realcat
# add: COTR(https://github.com/ubc-vision/COTR)
# 10dcc2e
# raw
# history blame
# 2.3 kB
'''
COTR demo for human face
We use an off-the-shelf face landmarks detector: https://github.com/1adrianb/face-alignment
'''
import argparse
import os
import sys
import time

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch

from COTR.utils import utils, debug_utils
from COTR.models import build_model
from COTR.options.options import *
from COTR.options.options_utils import *
from COTR.inference.inference_helper import triangulate_corr
from COTR.inference.sparse_engine import SparseEngine
# Import-time setup for the demo. The project's randomness helper is called
# with 0 (presumably a seed that makes runs repeatable -- confirm in COTR.utils).
utils.fix_randomness(0)
# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
def main(opt):
    """Match face landmarks between the two sample images and plot them.

    The model is built from ``opt``, moved to the GPU with ``.cuda()`` (so a
    CUDA device is required), and initialised from the ``'model_state_dict'``
    entry of the checkpoint at ``opt.load_weights_path``.  Landmarks for the
    reference face are read from ``./sample_data/face_landmarks.npy`` (first
    entry of the array) and passed to the sparse engine as queries.  The
    result is shown in a two-panel matplotlib figure; nothing is returned.
    """
    # Model: build -> GPU -> load checkpoint (deserialised on CPU) -> eval mode.
    net = build_model(opt).cuda()
    checkpoint = torch.load(opt.load_weights_path, map_location='cpu')
    utils.safe_load_weights(net, checkpoint['model_state_dict'])
    net = net.eval()

    # Sample inputs shipped with the repository.
    face_ref, face_tgt = [
        imageio.imread(path, pilmode='RGB')
        for path in ('./sample_data/imgs/face_1.png',
                     './sample_data/imgs/face_2.png')
    ]
    landmarks = np.load('./sample_data/face_landmarks.npy')[0]

    # Four evenly spaced values from 0.5 down to 0.0625 are handed to the
    # engine (presumably the multiscale zoom levels -- confirm in SparseEngine).
    scales = np.linspace(0.5, 0.0625, 4)
    matcher = SparseEngine(net, 32, mode='stretching')
    matches = matcher.cotr_corr_multiscale(
        face_ref, face_tgt, scales, 1, queries_a=landmarks, force=False)

    # Left panel: queries on the reference image.  Right panel: columns 2+ of
    # the returned correspondences on the target image (presumably each row
    # is (x_a, y_a, x_b, y_b) -- confirm against the engine's return format).
    _, (ax_ref, ax_tgt) = plt.subplots(1, 2)
    panels = (
        (ax_ref, face_ref, landmarks, 'Reference Face'),
        (ax_tgt, face_tgt, matches[:, 2:], 'Target Face'),
    )
    for axis, image, points, caption in panels:
        axis.imshow(image)
        axis.scatter(*points.T, s=1)
        axis.title.set_text(caption)
        axis.axis('off')
    plt.show()
if __name__ == "__main__":
    # Script entry point: parse the CLI, derive the dependent options and run
    # the demo.  ``set_COTR_arguments``, ``general_config`` and ``print_opt``
    # are not imported by name in this file; they are expected to arrive via
    # the star imports from ``COTR.options``.
    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    opt = parser.parse_args()

    # Keep the exact command line on the options object.
    opt.command = ' '.join(sys.argv)

    # The feed-forward width is tied to the selected backbone layer.
    layer_2_channels = {
        'layer1': 256,
        'layer2': 512,
        'layer3': 1024,
        'layer4': 2048,
    }
    opt.dim_feedforward = layer_2_channels[opt.layer]

    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')

    # main() reads opt.load_weights_path unconditionally.  Without a checkpoint
    # the original script only failed with an AttributeError after the model
    # had already been built and moved to the GPU; report the problem up front
    # as a normal usage error instead.
    if getattr(opt, 'load_weights_path', None) is None:
        parser.error('no checkpoint to load: pass --load_weights <model id> '
                     '(resolved as <out_dir>/<model id>/checkpoint.pth.tar)')

    print_opt(opt)
    main(opt)