File size: 4,373 Bytes
216282e
 
1fae98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216282e
1fae98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import json
import numpy as np
import cv2
from PIL import Image

# contrast correction, rescale and recenter
def image_preprocess_nosave(input_image, lower_contrast=True, rescale=True):

    image_arr = np.array(input_image)
    in_w, in_h = image_arr.shape[:2]

    if lower_contrast:
        alpha = 0.8  # Contrast control (1.0-3.0)
        beta =  0   # Brightness control (0-100)
        # Apply the contrast adjustment
        image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta)
        image_arr[image_arr[...,-1]>200, -1] = 255

    ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
    x, y, w, h = cv2.boundingRect(mask)
    max_size = max(w, h)
    ratio = 0.75
    if rescale:
        side_len = int(max_size / ratio)
    else:
        side_len = in_w
    padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    center = side_len//2
    padded_image[center-h//2:center-h//2+h, center-w//2:center-w//2+w] = image_arr[y:y+h, x:x+w]
    rgba = Image.fromarray(padded_image).resize((256, 256), Image.LANCZOS)

    rgba_arr = np.array(rgba) / 255.0
    rgb = rgba_arr[...,:3] * rgba_arr[...,-1:] + (1 - rgba_arr[...,-1:])
    return Image.fromarray((rgb * 255).astype(np.uint8))

# pose generation
def calc_pose(phis, thetas, size, radius = 1.2, device='cuda'):
    import torch
    def normalize(vectors):
        return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
    thetas = torch.FloatTensor(thetas).to(device)
    phis = torch.FloatTensor(phis).to(device)
    
    centers = torch.stack([
        radius * torch.sin(thetas) * torch.sin(phis),
        -radius * torch.cos(thetas) * torch.sin(phis),
        radius * torch.cos(phis),
    ], dim=-1) # [B, 3]

    # lookat
    forward_vector = normalize(centers).squeeze(0)
    up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1) 
    right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1))      
    if right_vector.pow(2).sum() < 0.01:
        right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)  
    up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1))     

    poses = torch.eye(4, dtype=torch.float, device=device)[:3].unsqueeze(0).repeat(size, 1, 1)
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers 
    return poses

def get_poses(init_elev):
    mid = init_elev
    deg = 10
    if init_elev <= 75:
        low = init_elev + 30
        # e.g. 30, 60, 20, 40, 30, 30, 50, 70, 50, 50
        
        elevations = np.radians([mid]*4 + [low]*4 + [mid-deg,mid+deg,mid,mid]*4 + [low-deg,low+deg,low,low]*4)
        img_ids = [f"{num}.png" for num in range(8)] + [f"{num}_{view_num}.png" for num in range(8) for view_num in range(4)]
    else:
        
        high = init_elev - 30
        elevations = np.radians([mid]*4 + [high]*4 + [mid-deg,mid+deg,mid,mid]*4 + [high-deg,high+deg,high,high]*4)
        img_ids = [f"{num}.png" for num in list(range(4)) + list(range(8,12))]  + \
                [f"{num}_{view_num}.png" for num in list(range(4)) + list(range(8,12)) for view_num in range(4)]
    overlook_theta = [30+x*90 for x in range(4)]
    eyelevel_theta = [60+x*90 for x in range(4)]
    source_theta_delta = [0, 0, -deg, deg]
    azimuths = np.radians(overlook_theta + eyelevel_theta + \
                            [view_theta + source for view_theta in overlook_theta for source in source_theta_delta] + \
                            [view_theta + source for view_theta in eyelevel_theta for source in source_theta_delta])
    return img_ids, calc_pose(elevations, azimuths, len(azimuths)).cpu().numpy()


def gen_poses(shape_dir, pose_est):
    img_ids, input_poses = get_poses(pose_est)
        
    out_dict = {}
    focal = 560/2; h = w = 256
    out_dict['intrinsics'] = [[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]
    out_dict['near_far'] = [1.2-0.7, 1.2+0.7]
    out_dict['c2ws'] = {}
    for view_id, img_id in enumerate(img_ids):
        pose = input_poses[view_id]
        pose = pose.tolist()
        pose = [pose[0], pose[1], pose[2], [0, 0, 0, 1]]
        out_dict['c2ws'][img_id] = pose
    json_path = os.path.join(shape_dir, 'pose.json')
    with open(json_path, 'w') as f:
        json.dump(out_dict, f, indent=4)