ID-Pose / src /pose_estimation.py
tokenid
fix plot stuck
146da98
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from datetime import datetime
from .ldm.util import load_and_preprocess, instantiate_from_config
from .pose_funcs import probe_pose, find_optimal_poses, get_inv_pose, add_pose, pairwise_loss
from .oee.utils.elev_est_api import elev_est_api, ElevEstHelper
from .sampling import sample_images
def load_image(img_path, mask_path=None, preprocessor=None, threshold=0.9):
    """Open an image file and turn it into an RGB array on a white background.

    Args:
        img_path: Path of the image to open with PIL.
        mask_path: Optional path of a mask image. Only used for 'RGB' inputs,
            where it selects which pixels are kept when compositing over white.
        preprocessor: When given, the opened image is handed to
            ``load_and_preprocess`` and its result is returned unchanged.
        threshold: For 'RGBA' inputs, pixels whose alpha (scaled to [0, 1]) is
            less than or equal to this value are painted white.

    Returns:
        The result of ``load_and_preprocess`` when a preprocessor is supplied;
        otherwise a float32 array scaled by 1/255 with the alpha channel
        removed. For any other image mode a message is printed and the opened
        PIL image is returned as-is.
    """
    img = Image.open(img_path)

    if preprocessor is not None:
        return load_and_preprocess(preprocessor, img)

    mode = img.mode

    if mode == 'RGBA':
        rgba = np.asarray(img, dtype=np.float32) / 255.
        # Pixels that are (almost) transparent become opaque white.
        is_background = rgba[:, :, -1] <= threshold
        rgba[is_background] = [1., 1., 1., 1.]
        return rgba[:, :, :3]

    if mode == 'RGB':
        if mask_path is not None:
            mask = Image.open(mask_path)
            white = Image.new('RGB', (img.width, img.height), color=(255, 255, 255))
            img = Image.composite(img, white, mask)
        return np.asarray(img, dtype=np.float32) / 255.

    # Unsupported mode: report it and hand back the untouched PIL image.
    print('Wrong format:', img_path)
    return img
def load_model_from_config(config, ckpt, device, verbose=False):
    """Build the model described by ``config.model`` and load checkpoint weights.

    Args:
        config: Object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint path given to ``torch.load``.
        device: Device used both as ``map_location`` and as the target of
            ``model.to``.
        verbose: When true, print the missing / unexpected state-dict keys
            (only if there are any).

    Returns:
        The instantiated model, moved to ``device`` and switched to eval mode.
    """
    print(f'Loading model from {ckpt}')
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location=device)
    if 'global_step' in checkpoint:
        step = checkpoint['global_step']
        print(f'Global Step: {step}')

    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported below instead of raising.
    missing, unexpected = model.load_state_dict(checkpoint['state_dict'], strict=False)

    if verbose:
        for title, keys in (('missing keys:', missing), ('unexpected keys:', unexpected)):
            if len(keys) > 0:
                print(title)
                print(keys)

    model.to(device)
    model.eval()
    return model
def estimate_elevs(model, images, est_type=None, matcher_ckpt_path=None):
    """Estimate per-image elevations and the elevation offsets to probe.

    Image 0 acts as the reference; ranges are only filled for images 1..N-1.

    Args:
        model: Model used to sample surrounding views; its ``device`` attribute
            is passed to the feature matcher factory.
        images: Sequence of input images (index 0 is the reference).
        est_type: ``'all'`` estimates an elevation for every image,
            ``'simple'`` estimates only the reference image. Any other value
            (including ``None``) skips estimation.
        matcher_ckpt_path: Checkpoint path forwarded to
            ``ElevEstHelper.get_feature_matcher``.

    Returns:
        ``(elevs, elev_ranges)``: two dicts keyed by image index. ``elevs``
        holds the estimate converted with ``np.deg2rad`` (or ``None``);
        ``elev_ranges`` holds an array of elevation offsets relative to the
        reference (or ``None`` when nothing is known).
    """
    count = len(images)
    elevs = dict.fromkeys(range(count))
    elev_ranges = dict.fromkeys(range(count))

    if est_type == 'all':
        matcher = ElevEstHelper.get_feature_matcher(matcher_ckpt_path, model.device)
        for idx in range(count):
            neighbours = sample_surrounding_images(model, images[idx])
            elevs[idx] = elev_est_api(matcher, neighbours, min_elev=20, max_elev=160)

        # The estimator output is converted from degrees to radians.
        for idx in range(count):
            if elevs[idx] is not None:
                elevs[idx] = np.deg2rad(elevs[idx])

        for idx in range(1, count):
            ref, cur = elevs[0], elevs[idx]
            if cur is not None:
                if ref is not None:
                    # Both known: a single exact offset.
                    elev_ranges[idx] = np.array([cur - ref])
                else:
                    # Only this view known: probe around it, sign flipped.
                    elev_ranges[idx] = -make_elev_probe_range(cur)
            elif ref is not None:
                # Only the reference known: probe around the reference.
                elev_ranges[idx] = make_elev_probe_range(ref)

    elif est_type == 'simple':
        matcher = ElevEstHelper.get_feature_matcher(matcher_ckpt_path, model.device)
        neighbours = sample_surrounding_images(model, images[0])
        estimate = elev_est_api(matcher, neighbours, min_elev=20, max_elev=160)
        elevs[0] = None if estimate is None else np.deg2rad(estimate)

        # Fall back to pi/2 for the reference when estimation failed.
        anchor = np.pi/2 if elevs[0] is None else elevs[0]
        for idx in range(1, count):
            elev_ranges[idx] = np.array([np.pi/2 - anchor])

    return elevs, elev_ranges
def estimate_poses(
    model, images,
    seed_cand_num=8,
    explore_type='pairwise',
    refine_type='pairwise',
    probe_ts_range=[0.02, 0.98], ts_range=[0.02, 0.98],
    probe_bsz=16,
    adjust_factor=10.,
    adjust_iters=10,
    adjust_bsz=1,
    refine_factor=1.,
    refine_iters=600,
    refine_bsz=1,
    noise=None,
    elevs=None,
    elev_ranges=None
):
    """Estimate the pose of every image relative to image 0.

    The work is done in three stages:

    1. Exploration: for each pair (0, i), candidate poses are probed over
       ``seed_cand_num`` evenly spaced azimuths in [0, 2*pi) (and over
       ``elev_ranges[i]``), optionally adjusted with a short optimisation,
       and the five lowest-loss candidates are kept.
    2. Selection: one candidate per image is chosen, either by its pairwise
       loss alone or with an additional cross-view term between images i and j.
    3. Refinement: the selected poses initialise a joint optimisation.

    Args:
        model: Model handed to the probing / optimisation helpers.
        images: Sequence of images; each must support ``permute(0, 2, 3, 1)``
            (presumably NCHW tensors turned into NHWC - TODO confirm).
        seed_cand_num: Number of azimuth seeds probed per pair.
        explore_type: ``'pairwise'`` or ``'triangular'``. Forced to
            ``'pairwise'`` when there are at most two images. With any other
            value all selection losses are 0 and the first candidate wins.
        refine_type: ``'pairwise'`` refines over the pairs (0, i) and (i, 0);
            ``'triangular'`` refines over every ordered pair of images.
        probe_ts_range: Range forwarded to ``probe_pose`` / ``pairwise_loss``.
            The list default is shared between calls and is only read here.
        ts_range: Range forwarded to ``find_optimal_poses`` (same note).
        probe_bsz: Batch size forwarded to ``probe_pose`` / ``pairwise_loss``.
        adjust_factor: Second positional argument of ``find_optimal_poses``
            during candidate adjustment.
        adjust_iters: Iterations per candidate adjustment; ``0`` (or less)
            skips adjustment and keeps the five best probed candidates.
        adjust_bsz: Batch size for candidate adjustment.
        refine_factor: Second positional argument of ``find_optimal_poses``
            during refinement.
        refine_iters: Refinement iterations per non-reference image; the total
            is ``(len(images) - 1) * refine_iters``.
        refine_bsz: Batch size for refinement.
        noise: Forwarded to ``probe_pose`` / ``pairwise_loss``.
        elevs: Optional dict of per-image elevations; only echoed back in the
            auxiliary data. Defaults to all ``None``.
        elev_ranges: Optional dict of per-image elevation offsets used as
            ``theta_range`` when probing. Defaults to all ``None``.

    Returns:
        ``(out_poses, aux_data)`` where ``out_poses`` is the first result of
        the refinement call and ``aux_data`` is a dict with keys
        ``'tri_ep_sph'`` (selected poses), ``'pw_ep_sph'`` (best pairwise
        poses) and ``'elev'`` (the ``elevs`` dict).

    Raises:
        ValueError: If ``refine_type`` is not ``'pairwise'`` or
            ``'triangular'``.
    """
    # Fail fast: previously an unknown refine_type only surfaced as a
    # TypeError on len(None) after the whole exploration stage had run.
    if refine_type not in ('pairwise', 'triangular'):
        raise ValueError(
            f"refine_type must be 'pairwise' or 'triangular', got {refine_type!r}"
        )

    def by_loss(cand):
        # Candidates are tuples whose first item is the loss. Ordering on the
        # loss alone avoids comparing pose objects when two losses tie.
        return cand[0]

    top_k = 5
    num = len(images)

    if elevs is None:
        elevs = {i: None for i in range(num)}
    if elev_ranges is None:
        elev_ranges = {i: None for i in range(num)}

    # A cross-view check needs at least two non-reference images.
    if num <= 2:
        explore_type = 'pairwise'

    cands = {}
    losses = {}
    ep_poses = {i: None for i in range(num)}
    pairwise_ep_poses = {i: None for i in range(num)}

    print('Start', datetime.now())

    images = [img.permute(0, 2, 3, 1) for img in images]

    # ---- Exploration: candidate poses for every pair (0, i) ----
    for i in range(1, num):
        print('PAIR', 0, i, datetime.now())

        azimuth_range = np.arange(start=0.0, stop=np.pi*2, step=np.pi*2 / seed_cand_num)
        all_cands = probe_pose(
            model, images[0], images[i], probe_ts_range, probe_bsz,
            theta_range=elev_ranges[i], azimuth_range=azimuth_range, noise=noise
        )
        all_cands = sorted(all_cands, key=by_loss)

        print('Exploration', len(all_cands), datetime.now())

        adjusted_cands = all_cands[:top_k]
        if adjust_iters > 0:
            adjusted_cands = []
            # Only adjust the better half of the probed candidates.
            for cand in all_cands[:len(all_cands)//2]:
                out_poses, _, _ = find_optimal_poses(
                    model, [images[0], images[i]],
                    adjust_factor,
                    bsz=adjust_bsz,
                    n_iter=adjust_iters,
                    init_poses={1: cand[1]},
                    ts_range=ts_range,
                    print_n=100,
                    avg_last_n=1
                )
                loss = pairwise_loss(
                    out_poses[0], model, images[0], images[i],
                    probe_ts_range, probe_bsz, noise=noise
                )
                # Keep the pre-adjustment loss and pose for the log below.
                adjusted_cands.append((loss, out_poses[0], cand[0], cand[1]))
            adjusted_cands = sorted(adjusted_cands, key=by_loss)[:top_k]

        for cand in adjusted_cands:
            print(cand)

        cands[i] = [cand[:2] for cand in adjusted_cands]
        # Selection losses start from the pairwise loss only in 'pairwise' mode.
        losses[i] = [loss if (explore_type == 'pairwise') else 0.0 for loss, _ in cands[i]]
        pairwise_ep_poses[i] = min(cands[i], key=by_loss)[1]

    print('Selection', datetime.now())

    # ---- Selection: optional cross-view consistency between i and j ----
    if explore_type == 'triangular':
        for i in range(1, num):
            for j in range(i+1, num):
                n_i, n_j = len(cands[i]), len(cands[j])
                iloss = [[None for v in range(n_j)] for u in range(n_i)]
                jloss = [[None for u in range(n_i)] for v in range(n_j)]

                for u in range(n_i):
                    la, pa = cands[i][u]
                    # pose i -> 0
                    pa = get_inv_pose(pa)
                    for v in range(n_j):
                        # pose 0 -> j
                        lb, pb = cands[j][v]
                        theta, azimuth, radius = add_pose(pa, pb)
                        lp = pairwise_loss(
                            [theta, azimuth, radius], model, images[i], images[j],
                            probe_ts_range, probe_bsz, noise=noise
                        )
                        total = la + lb + lp
                        iloss[u][v] = total
                        jloss[v][u] = total

                # Each candidate takes its best partner, capped at three times
                # its own pairwise loss.
                for u in range(n_i):
                    losses[i][u] += min(min(iloss[u]), cands[i][u][0]*3)
                for v in range(n_j):
                    losses[j][v] += min(min(jloss[v]), cands[j][v][0]*3)

    for i in range(1, num):
        # First index with the smallest accumulated loss.
        min_rank = min(range(len(losses[i])), key=lambda x: losses[i][x])
        for u in range(0, len(cands[i])):
            print(cands[i][u], losses[i][u])
        print(i, 'SELECT', min_rank, losses[i][min_rank])
        ep_poses[i] = cands[i][min_rank][1]

    print('Refinement', datetime.now())

    # ---- Refinement: joint optimisation over the chosen image pairs ----
    if refine_type == 'pairwise':
        combinations = [(0, i) for i in range(1, num)] + [(i, 0) for i in range(1, num)]
    else:  # 'triangular' (validated at the top)
        combinations = []
        for i in range(0, num):
            for j in range(i+1, num):
                combinations.append((i, j))
                combinations.append((j, i))

    print('Combinations', len(combinations), combinations)

    out_poses, _, loss = find_optimal_poses(
        model, images,
        refine_factor,
        bsz=refine_bsz,
        n_iter=(num-1)*refine_iters,
        init_poses=ep_poses,
        ts_range=ts_range,
        combinations=combinations,
        avg_last_n=20,
        print_n=100
    )

    print('Done', datetime.now())

    aux_data = {
        'tri_ep_sph': ep_poses,
        'pw_ep_sph': pairwise_ep_poses,
        'elev': elevs
    }

    return out_poses, aux_data
def make_elev_probe_range(elev, interval=np.pi/4):
    """Return elevation offsets, relative to ``elev``, to probe across (0, pi).

    Absolute elevations are generated in two runs: from ``elev`` stepping
    down by ``interval`` while staying above 0, then from ``elev + interval``
    stepping up by ``interval`` while staying below pi. ``elev`` is then
    subtracted, so the first offset is 0 whenever ``elev`` is positive.

    Args:
        elev: Known elevation the offsets are measured from.
        interval: Step between consecutive probed elevations.

    Returns:
        1-D array: the decreasing run followed by the increasing run.
    """
    offsets = np.concatenate((
        np.arange(elev, 0, -interval),
        np.arange(elev + interval, np.pi, interval),
    ))
    # Turn absolute elevations into offsets from the known elevation.
    offsets -= elev
    return offsets
def sample_surrounding_images(model, image):
    """Sample four views of ``image`` at +/-10 degree offsets and combine them.

    ``sample_images`` is called four times with one sample each: the first
    offset argument set to -10 deg and +10 deg (second left at 0), then the
    second offset argument set to -10 deg and +10 deg (first left at 0).
    Presumably these two arguments are the elevation and azimuth changes -
    confirm against ``sample_images``.

    Returns:
        The four results joined left to right with ``+``.
    """
    minus = float(np.deg2rad(-10))
    plus = float(np.deg2rad(+10))
    offset_pairs = ((minus, 0), (plus, 0), (0, minus), (0, plus))

    views = [
        sample_images(model, image, first, second, 0, n_samples=1)
        for first, second in offset_pairs
    ]
    return views[0] + views[1] + views[2] + views[3]