|
|
|
import importlib
import os
import subprocess
import sys

# Install PerspectiveFields on first run if it is not already importable.
try:
    import perspective2d
except ImportError:
    # Invoke pip through the running interpreter so the package lands in the
    # same environment this script executes in. Best effort: a failed install
    # surfaces at the `from perspective2d import ...` statements further down.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/jinlinyi/PerspectiveFields.git@v1.0.0",
        ],
        check=False,
    )
    # Make the import system notice the freshly installed package.
    importlib.invalidate_caches()
|
|
|
|
|
import gradio as gr |
|
import cv2 |
|
import copy |
|
import numpy as np |
|
import os.path as osp |
|
from datetime import datetime |
|
|
|
import torch |
|
from PIL import Image, ImageDraw |
|
from glob import glob |
|
|
|
from perspective2d import PerspectiveFields |
|
from perspective2d.utils import draw_perspective_fields, draw_from_r_p_f_cx_cy |
|
from perspective2d.perspectivefields import model_zoo |
|
|
|
|
|
|
|
# Header, description and footer text handed to gr.Interface; all are empty
# strings.
title = description = article = ""
|
|
|
|
|
def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
    """Resize an image and its perspective-field maps, keeping the aspect ratio.

    When only one target dimension is given, the other is derived from the
    image's aspect ratio. When both are given, the larger of the two scale
    factors is used, so the result covers the requested size and the other
    dimension is recomputed from that factor.

    Args:
        img: Image array; ``img.shape[0]`` is read as the height and
            ``img.shape[1]`` as the width.
        field: Mapping of tensors. Only the ``'up'`` and ``'lati'`` entries
            are resized; they are replaced in ``field`` itself. Each entry
            must support ``.numpy()``.
        target_width: Desired output width in pixels, or ``None``.
        target_height: Desired output height in pixels, or ``None``.

    Returns:
        Tuple ``(img, field)`` with the resized image and the same ``field``
        mapping holding the resized tensors.

    Raises:
        ValueError: If both ``target_width`` and ``target_height`` are ``None``.
    """
    if target_width is None and target_height is None:
        raise ValueError(
            "at least one of target_width or target_height must be given"
        )

    height, width = img.shape[0], img.shape[1]
    if target_height is None:
        target_height = int(height * (target_width / width))
    elif target_width is None:
        target_width = int(width * (target_height / height))
    else:
        width_factor = target_width / width
        height_factor = target_height / height
        # Scale by the larger factor and recompute the other dimension.
        if width_factor >= height_factor:
            target_height = int(height * width_factor)
        else:
            target_width = int(width * height_factor)

    new_size = (target_width, target_height)  # cv2 takes (width, height)
    img = cv2.resize(img, new_size)

    for key in ('up', 'lati'):
        if key not in field:
            continue
        arr = field[key].numpy()
        # 3-D maps are treated as channel-first: move axis 0 to the end for
        # cv2.resize and move it back afterwards.
        channel_first = arr.ndim == 3
        if channel_first:
            arr = arr.transpose(1, 2, 0)
        arr = cv2.resize(arr, new_size)
        if channel_first:
            arr = arr.transpose(2, 0, 1)
        field[key] = torch.tensor(arr)
    return img, field
|
|
|
|
|
def inference(img_rgb, model_type):
    """Run a PerspectiveFields model on an image and visualise the prediction.

    Args:
        img_rgb: Input image array in RGB channel order, or ``None`` when no
            image was supplied.
        model_type: Key into ``model_zoo`` selecting the model, or ``None``.

    Returns:
        Tuple ``(image, text)``: a PIL image of the visualisation and a text
        summary of the predicted camera parameters. The text is
        ``"Not Implemented"`` for models whose ``model_zoo`` entry has a falsy
        ``'param'`` value. Returns ``(None, "")`` when either input is ``None``.
    """
    # Nothing to do without both an image and a model selection.
    if img_rgb is None or model_type is None:
        return None, ""

    pf_model = PerspectiveFields(model_type).eval().to(device)
    # The model's inference API takes BGR, so reverse the channel axis.
    pred = pf_model.inference(img_bgr=img_rgb[..., ::-1])
    # Height of the original image, captured before the resize below; used to
    # scale the relative focal length.
    img_h = img_rgb.shape[0]
    field = {
        'up': pred['pred_gravity_original'].cpu().detach(),
        'lati': pred['pred_latitude_original'].cpu().detach(),
    }
    img_rgb, field = resize_fix_aspect_ratio(img_rgb, field, 640)

    if not model_zoo[model_type]['param']:
        pred_vis = draw_perspective_fields(
            img_rgb,
            field['up'],
            torch.deg2rad(field['lati']),
            color=(0, 1, 0),
        )
        param = "Not Implemented"
    else:
        # Pull each scalar off the device once and reuse it below.
        roll = pred['pred_roll'].cpu().item()
        pitch = pred['pred_pitch'].cpu().item()
        vfov = pred['pred_general_vfov'].cpu().item()
        rel_cx = pred['pred_rel_cx'].cpu().item()
        rel_cy = pred['pred_rel_cy'].cpu().item()
        rel_focal = pred['pred_rel_focal'].cpu().item()

        r_p_f_rad = np.radians([roll, pitch, vfov])
        cx_cy = [rel_cx, rel_cy]
        param = (
            f"roll {roll:.2f}\npitch {pitch:.2f}\nvertical fov {vfov:.2f}\n"
            f"focal_length {rel_focal * img_h:.2f}\n"
            f"principal point {rel_cx:.2f} {rel_cy:.2f}"
        )
        pred_vis = draw_from_r_p_f_cx_cy(
            img_rgb,
            *r_p_f_rad,
            *cx_cy,
            'rad',
            up_color=(0, 1, 0),
        )

    # Log one record per request to stdout.
    print(f"""time {datetime.now().strftime("%H:%M:%S")}
img.shape {img_rgb.shape}
model_type {model_type}
param {param}
"""
    )
    return Image.fromarray(pred_vis), param
|
|
|
# One single-element row per image file found under assets/imgs (the pattern
# matches any name whose extension ends in "g").
examples = [[img_path] for img_path in glob('assets/imgs/*.*g')]
print(examples)

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
|
|
|
# Help text attached to the model selector.
info = """Select model\n"""

# Build the demo: an image plus a model choice go in, the visualisation and a
# text summary of the camera parameters come out.
model_names = list(model_zoo.keys())
demo = gr.Interface(
    fn=inference,
    inputs=[
        "image",
        gr.Radio(
            model_names,
            # Default to the alphabetically first model name.
            value=sorted(model_names)[0],
            label="Model",
            info=info,
        ),
    ],
    outputs=[
        gr.Image(label='Perspective Fields'),
        gr.Textbox(label='Pred Camera Parameters'),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()