import numpy as np import torch import random # Reworked so this matches gluPerspective / glm::perspective, using fovy def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None): # y = np.tan(fovy / 2) x = np.tan(fovx / 2) return torch.tensor([[1/x, 0, 0, 0], [ 0, -aspect/x, 0, 0], [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)], [ 0, 0, -1, 0]], dtype=torch.float32, device=device) def translate(x, y, z, device=None): return torch.tensor([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]], dtype=torch.float32, device=device) def rotate_x(a, device=None): s, c = np.sin(a), np.cos(a) return torch.tensor([[1, 0, 0, 0], [0, c, -s, 0], [0, s, c, 0], [0, 0, 0, 1]], dtype=torch.float32, device=device) def rotate_y(a, device=None): s, c = np.sin(a), np.cos(a) return torch.tensor([[ c, 0, s, 0], [ 0, 1, 0, 0], [-s, 0, c, 0], [ 0, 0, 0, 1]], dtype=torch.float32, device=device) def rotate_z(a, device=None): s, c = np.sin(a), np.cos(a) return torch.tensor([[c, -s, 0, 0], [s, c, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=torch.float32, device=device) @torch.no_grad() def batch_random_rotation_translation(b, t, device=None): m = np.random.normal(size=[b, 3, 3]) m[:, 1] = np.cross(m[:, 0], m[:, 2]) m[:, 2] = np.cross(m[:, 0], m[:, 1]) m = m / np.linalg.norm(m, axis=2, keepdims=True) m = np.pad(m, [[0, 0], [0, 1], [0, 1]], mode='constant') m[:, 3, 3] = 1.0 m[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3]) return torch.tensor(m, dtype=torch.float32, device=device) @torch.no_grad() def random_rotation_translation(t, device=None): m = np.random.normal(size=[3, 3]) m[1] = np.cross(m[0], m[2]) m[2] = np.cross(m[0], m[1]) m = m / np.linalg.norm(m, axis=1, keepdims=True) m = np.pad(m, [[0, 1], [0, 1]], mode='constant') m[3, 3] = 1.0 m[:3, 3] = np.random.uniform(-t, t, size=[3]) return torch.tensor(m, dtype=torch.float32, device=device) @torch.no_grad() def random_rotation(device=None): m = np.random.normal(size=[3, 3]) m[1] = np.cross(m[0], m[2]) m[2] = np.cross(m[0], m[1]) m = m / np.linalg.norm(m, axis=1, keepdims=True) m = np.pad(m, [[0, 1], [0, 1]], mode='constant') m[3, 3] = 1.0 m[:3, 3] = np.array([0,0,0]).astype(np.float32) return torch.tensor(m, dtype=torch.float32, device=device) def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.sum(x*y, -1, keepdim=True) def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: return torch.sqrt(torch.clamp(dot(x,x), min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor: return x / length(x, eps) def lr_schedule(iter, warmup_iter, scheduler_decay): if iter < warmup_iter: return iter / warmup_iter return max(0.0, 10 ** ( -(iter - warmup_iter) * scheduler_decay)) def trans_depth(depth): depth = depth[0].detach().cpu().numpy() valid = depth > 0 depth[valid] -= depth[valid].min() depth[valid] = ((depth[valid] / depth[valid].max()) * 255) return depth.astype('uint8') def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): assert isinstance(input, torch.Tensor) if posinf is None: posinf = torch.finfo(input.dtype).max if neginf is None: neginf = torch.finfo(input.dtype).min assert nan == 0 return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) def load_item(filepath): with open(filepath, 'r') as f: items = [name.strip() for name in f.readlines()] return set(items) def load_prompt(filepath): uuid2prompt = {} with open(filepath, 'r') as f: for line in f.readlines(): list_line = line.split(',') uuid2prompt[list_line[0]] = ','.join(list_line[1:]).strip() return uuid2prompt def resize_and_center_image(image_tensor, scale=0.95, c = 0, shift = 0, rgb=False, aug_shift = 0): if scale == 1: return image_tensor B, C, H, W = image_tensor.shape new_H, new_W = int(H * scale), int(W * scale) resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False).squeeze(0) background = torch.zeros_like(image_tensor) + c start_y, start_x = (H - new_H) // 2, (W - new_W) // 2 if shift == 0: background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image else: for i in range(B): randx = random.randint(-shift, shift) randy = random.randint(-shift, shift) if rgb == True: if i == 0 or i==2 or i==4: randx = 0 randy = 0 background[i, :, start_y+randy:start_y + new_H+randy, start_x+randx:start_x + new_W+randx] = resized_image[i] if aug_shift == 0: return background for i in range(B): for j in range(C): background[i, j, :, :] += (random.random() - 0.5)*2 * aug_shift / 255 return background def get_tri(triview_color, dim = 1, blender=True, c = 0, scale=0.95, shift = 0, fix = False, rgb=False, aug_shift = 0): # triview_color: [6,C,H,W] # rgb is useful when shift is not 0 triview_color = resize_and_center_image(triview_color, scale=scale, c = c, shift=shift,rgb=rgb, aug_shift = aug_shift) if blender is False: triview_color0 = torch.rot90(triview_color[0],k=2,dims=[1,2]) triview_color1 = torch.rot90(triview_color[4],k=1,dims=[1,2]).flip(2).flip(1) triview_color2 = torch.rot90(triview_color[5],k=1,dims=[1,2]).flip(2) triview_color3 = torch.rot90(triview_color[3],k=2,dims=[1,2]).flip(2) triview_color4 = torch.rot90(triview_color[1],k=3,dims=[1,2]).flip(1) triview_color5 = torch.rot90(triview_color[2],k=3,dims=[1,2]).flip(1).flip(2) else: triview_color0 = torch.rot90(triview_color[2],k=2,dims=[1,2]) triview_color1 = torch.rot90(triview_color[4],k=0,dims=[1,2]).flip(2).flip(1) triview_color2 = torch.rot90(torch.rot90(triview_color[0],k=3,dims=[1,2]).flip(2), k=2,dims=[1,2]) triview_color3 = torch.rot90(torch.rot90(triview_color[5],k=2,dims=[1,2]).flip(2), k=2,dims=[1,2]) triview_color4 = torch.rot90(triview_color[1],k=2,dims=[1,2]).flip(1).flip(1).flip(2) triview_color5 = torch.rot90(triview_color[3],k=1,dims=[1,2]).flip(1).flip(2) if fix == True: triview_color0[1] = triview_color0[1] * 0 triview_color0[2] = triview_color0[2] * 0 triview_color3[1] = triview_color3[1] * 0 triview_color3[2] = triview_color3[2] * 0 triview_color1[0] = triview_color1[0] * 0 triview_color1[1] = triview_color1[1] * 0 triview_color4[0] = triview_color4[0] * 0 triview_color4[1] = triview_color4[1] * 0 triview_color2[0] = triview_color2[0] * 0 triview_color2[2] = triview_color2[2] * 0 triview_color5[0] = triview_color5[0] * 0 triview_color5[2] = triview_color5[2] * 0 color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2) color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2) color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim = dim) return color_tensor_gt