import os
import importlib
import subprocess
import sys

# Bootstrap: install the pinned PerspectiveFields release on first run when the
# `perspective2d` package cannot be imported.
try:
    import perspective2d
except ImportError:
    # Use the running interpreter's pip (not whatever `pip` is first on PATH)
    # and fail loudly if the install fails, instead of continuing on to a
    # confusing ImportError further down the file.
    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/jinlinyi/PerspectiveFields.git@v1.0.0",
        ]
    )
    # The package was installed while the interpreter is running; make the
    # import system re-scan sys.path so the later `from perspective2d ...`
    # imports can find it.
    importlib.invalidate_caches()
| import gradio as gr | |
| import cv2 | |
| import copy | |
| import numpy as np | |
| import os.path as osp | |
| from datetime import datetime | |
| import torch | |
| from PIL import Image, ImageDraw | |
| from glob import glob | |
| from perspective2d import PerspectiveFields | |
| from perspective2d.utils import draw_perspective_fields, draw_from_r_p_f_cx_cy | |
| from perspective2d.perspectivefields import model_zoo | |
# Header/footer text shown by the Gradio interface; all currently empty.
title, description, article = "", "", ""
def _fit_size(height, width, target_width=None, target_height=None):
    """Return the ``(width, height)`` to resize a ``height x width`` image to.

    If only one target is given, the other side follows the aspect ratio.
    If both are given, the larger of the two scale factors is used, so the
    result covers the whole target box. That side keeps its target value and
    the other side is scaled.

    Raises:
        ValueError: if neither ``target_width`` nor ``target_height`` is given.
    """
    if target_width is None and target_height is None:
        raise ValueError("target_width or target_height must be given")
    width_factor = None if target_width is None else target_width / width
    height_factor = None if target_height is None else target_height / height
    # Ties keep the requested width, matching the previous behaviour.
    if height_factor is None or (
        width_factor is not None and width_factor >= height_factor
    ):
        return target_width, max(1, int(height * width_factor))
    return max(1, int(width * height_factor)), target_height


def _resize_map(arr, size):
    """Resize a ``(H, W)`` or channel-first ``(C, H, W)`` array to ``size=(width, height)``."""
    if arr.ndim != 3:
        return cv2.resize(arr, size)
    # cv2 expects channels last.
    resized = cv2.resize(arr.transpose(1, 2, 0), size)
    if resized.ndim == 2:
        # cv2 drops a trailing singleton channel; restore it before transposing back.
        resized = resized[..., None]
    return resized.transpose(2, 0, 1)


def resize_fix_aspect_ratio(img, field, target_width=None, target_height=None):
    """Resize an image and its per-pixel perspective field, keeping aspect ratio.

    Args:
        img: image array whose first two axes are (height, width).
        field: dict of torch tensors. The ``'up'`` and ``'lati'`` entries are
            resized to match the image; 2-D tensors are treated as (H, W) and
            3-D tensors as (C, H, W). Other entries are passed through unchanged.
        target_width: desired output width, or None to derive it from height.
        target_height: desired output height, or None to derive it from width.
            When both are given, the output covers the target box (see
            ``_fit_size``).

    Returns:
        ``(resized_img, resized_field)``. ``resized_field`` is a new dict; the
        caller's ``field`` dict is not modified.

    Raises:
        ValueError: if neither target dimension is given.
    """
    height, width = img.shape[:2]
    size = _fit_size(height, width, target_width, target_height)
    img = cv2.resize(img, size)
    field = dict(field)  # don't mutate the caller's dict
    for key in ('up', 'lati'):
        if key in field:
            field[key] = torch.tensor(_resize_map(field[key].numpy(), size))
    return img, field
# Single-slot cache: model_type -> loaded model (at most one entry).
_MODEL_CACHE = {}


def _get_model(model_type):
    """Return an eval-mode ``PerspectiveFields`` model for ``model_type`` on ``device``.

    Repeated requests for the same model reuse the loaded instance instead of
    reconstructing it (and reloading weights) on every call. Switching models
    drops the reference to the previously cached one, so at most one cached
    model is kept. Assumes ``PerspectiveFields.inference`` keeps no per-call
    state on the model; confirm if that changes.
    """
    model = _MODEL_CACHE.get(model_type)
    if model is None:
        _MODEL_CACHE.clear()
        model = PerspectiveFields(model_type).eval().to(device)
        _MODEL_CACHE[model_type] = model
    return model


def inference(img_rgb, model_type):
    """Predict Perspective Fields for an image and render a visualization.

    Args:
        img_rgb: image array from the Gradio "image" input, indexed
            (H, W, C) with RGB channel order. The channels are reversed to BGR
            for the model. May be None if the user submits without an image.
        model_type: key into ``model_zoo`` selecting the model, or None.

    Returns:
        ``(PIL.Image, str)``:
        - the visualization drawn on the image resized to width 640;
        - a text summary of the predicted camera parameters, or
          "Not Implemented" for models whose ``model_zoo`` entry has no
          ``'param'`` head.
        Returns ``(None, "")`` when either input is missing.
    """
    if img_rgb is None or model_type is None:
        return None, ""
    pf_model = _get_model(model_type)
    pred = pf_model.inference(img_bgr=img_rgb[..., ::-1])
    # Focal length is predicted relative to image height; capture the
    # original height before the visualization resize below.
    img_h = img_rgb.shape[0]
    field = {
        'up': pred['pred_gravity_original'].cpu().detach(),
        'lati': pred['pred_latitude_original'].cpu().detach(),
    }
    img_rgb, field = resize_fix_aspect_ratio(img_rgb, field, 640)
    if not model_zoo[model_type]['param']:
        pred_vis = draw_perspective_fields(
            img_rgb,
            field['up'],
            torch.deg2rad(field['lati']),
            color=(0, 1, 0),
        )
        param = "Not Implemented"
    else:
        roll, pitch, vfov, rel_focal, rel_cx, rel_cy = (
            pred[key].cpu().item()
            for key in (
                'pred_roll',
                'pred_pitch',
                'pred_general_vfov',
                'pred_rel_focal',
                'pred_rel_cx',
                'pred_rel_cy',
            )
        )
        param = (
            f"roll {roll:.2f}\n"
            f"pitch {pitch:.2f}\n"
            f"vertical fov {vfov:.2f}\n"
            f"focal_length {rel_focal * img_h:.2f}\n"
            f"principal point {rel_cx:.2f} {rel_cy:.2f}"
        )
        pred_vis = draw_from_r_p_f_cx_cy(
            img_rgb,
            *np.radians([roll, pitch, vfov]),
            rel_cx,
            rel_cy,
            'rad',
            up_color=(0, 1, 0),
        )
    print(
        f"time {datetime.now().strftime('%H:%M:%S')}\n"
        f"img.shape {img_rgb.shape}\n"
        f"model_type {model_type}\n"
        f"param {param}\n"
    )
    return Image.fromarray(pred_vis), param
# One single-column example row per demo image whose extension ends in "g"
# (e.g. .jpg, .jpeg, .png) under assets/imgs.
examples = [[img_name] for img_name in glob('assets/imgs/*.*g')]
print(examples)
# Use the GPU when torch can see one, else the CPU. `inference` reads this
# module-level name at call time.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
info = "Select model\n"

# Radio choices keep model_zoo's own key order; the preselected value is the
# alphabetically first key.
model_choices = list(model_zoo.keys())
default_model = sorted(model_zoo.keys())[0]

demo = gr.Interface(
    fn=inference,
    inputs=[
        "image",
        gr.Radio(model_choices, value=default_model, label="Model", info=info),
    ],
    outputs=[
        gr.Image(label='Perspective Fields'),
        gr.Textbox(label='Pred Camera Parameters'),
    ],
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch()
