kyleleey committed
Commit 3bf37d0 • 1 Parent(s): d2d9973

init demo

Files changed:
- .gitignore  +4 -13
- app.py  +619 -0
- requirements.txt  +32 -0

.gitignore
CHANGED
@@ -1,22 +1,13 @@
 __pycache__
-data
-data/*/
-data/*/*
+# data
+# data/*/
+# data/*/*

 pretrained/*/
 results
 neural_renderer
 *.zip
 unchanged/
-cvpr23_results/
-# slurm.bash
-results
-results/*/
-results/*
-results/*/*
-results/dor_checkpoints/*
-results/dor_checkpoints/*/*
-results/dor_checkpoints/*/*/*


 .vscode

app.py
ADDED
@@ -0,0 +1,619 @@
import os
import fire
import gradio as gr
from PIL import Image
from functools import partial
import argparse

import cv2
import time
import numpy as np
import trimesh
from segment_anything import sam_model_registry, SamPredictor

import random
from pytorch3d import transforms
import torch
import torchvision
import torch.distributed as dist
import nvdiffrast.torch as dr
from video3d.model_ddp import Unsup3DDDP, forward_to_matrix
from video3d.trainer_few_shot import Fewshot_Trainer
from video3d.trainer_ddp import TrainerDDP
from video3d import setup_runtime
from video3d.render.mesh import make_mesh
from video3d.utils.skinning_v4 import estimate_bones, skinning, euler_angles_to_matrix
from video3d.utils.misc import save_obj
from video3d.render import util
import matplotlib.pyplot as plt
from pytorch3d import utils, renderer, transforms, structures, io
from video3d.render.render import render_mesh
from video3d.render.material import texture as material_texture


_TITLE = '''Learning the 3D Fauna of the Web'''
_DESCRIPTION = '''
<div>
Reconstruct any quadruped animal from one image.
</div>
<div>
The demo only contains the 3D reconstruction part.
</div>
'''
_GPU_ID = 0

if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image


def sam_init():
    sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
    model_type = "vit_h"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=f"cuda:{_GPU_ID}")
    predictor = SamPredictor(sam)
    return predictor


def sam_segment(predictor, input_image, *bbox_coords):
    bbox = np.array(bbox_coords)
    image = np.asarray(input_image)

    start_time = time.time()
    predictor.set_image(image)

    masks_bbox, scores_bbox, logits_bbox = predictor.predict(
        box=bbox,
        multimask_output=True
    )

    print(f"SAM Time: {time.time() - start_time:.3f}s")
    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    out_image[:, :, :3] = image
    out_image_bbox = out_image.copy()
    out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255
    torch.cuda.empty_cache()
    return Image.fromarray(out_image_bbox, mode='RGB')
    # return Image.fromarray(out_image_bbox, mode='RGBA')


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def preprocess(predictor, input_image, chk_group=None, segment=True):
    RES = 1024
    input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
    if chk_group is not None:
        segment = "Use SAM to center animal" in chk_group
    if segment:
        image_rem = input_image.convert('RGB')
        arr = np.asarray(image_rem)[:,:,-1]
        x_nonzero = np.nonzero(arr.sum(axis=0))
        y_nonzero = np.nonzero(arr.sum(axis=1))
        x_min = int(x_nonzero[0].min())
        y_min = int(y_nonzero[0].min())
        x_max = int(x_nonzero[0].max())
        y_max = int(y_nonzero[0].max())
        input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
    # Rescale and recenter
    # if rescale:
    #     image_arr = np.array(input_image)
    #     in_w, in_h = image_arr.shape[:2]
    #     out_res = min(RES, max(in_w, in_h))
    #     ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY)
    #     x, y, w, h = cv2.boundingRect(mask)
    #     max_size = max(w, h)
    #     ratio = 0.75
    #     side_len = int(max_size / ratio)
    #     padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
    #     center = side_len//2
    #     padded_image[center-h//2:center-h//2+h, center-w//2:center-w//2+w] = image_arr[y:y+h, x:x+w]
    #     rgba = Image.fromarray(padded_image).resize((out_res, out_res), Image.LANCZOS)

    #     rgba_arr = np.array(rgba) / 255.0
    #     rgb = rgba_arr[...,:3] * rgba_arr[...,-1:] + (1 - rgba_arr[...,-1:])
    #     input_image = Image.fromarray((rgb * 255).astype(np.uint8))
    # else:
    #     input_image = expand2square(input_image, (127, 127, 127, 0))

    input_image = expand2square(input_image, (0, 0, 0))
    return input_image, input_image.resize((256, 256), Image.Resampling.LANCZOS)


def save_images(images, mask_pred, mode="transparent"):
    img = images[0]
    mask = mask_pred[0]
    img = img.clamp(0, 1)
    if mask is not None:
        mask = mask.clamp(0, 1)
        if mode == "white":
            img = img * mask + 1 * (1 - mask)
        elif mode == "black":
            img = img * mask + 0 * (1 - mask)
        else:
            img = torch.cat([img, mask[0:1]], 0)

    img = img.permute(1, 2, 0).cpu().numpy()
    img = Image.fromarray(np.uint8(img * 255))
    return img


def get_bank_embedding(rgb, memory_bank_keys, memory_bank, model, memory_bank_topk=10, memory_bank_dim=128):
    images = rgb
    batch_size, num_frames, _, h0, w0 = images.shape
    images = images.reshape(batch_size*num_frames, *images.shape[2:])  # 0~1
    images_in = images * 2 - 1  # rescale to (-1, 1) for DINO

    x = images_in
    with torch.no_grad():
        b, c, h, w = x.shape
        model.netInstance.netEncoder._feats = []
        model.netInstance.netEncoder._register_hooks([11], 'key')
        #self._register_hooks([11], 'token')
        x = model.netInstance.netEncoder.ViT.prepare_tokens(x)
        #x = self.ViT.prepare_tokens_with_masks(x)

        for blk in model.netInstance.netEncoder.ViT.blocks:
            x = blk(x)
        out = model.netInstance.netEncoder.ViT.norm(x)
        model.netInstance.netEncoder._unregister_hooks()

        ph, pw = h // model.netInstance.netEncoder.patch_size, w // model.netInstance.netEncoder.patch_size
        patch_out = out[:, 1:]  # first is class token
        patch_out = patch_out.reshape(b, ph, pw, model.netInstance.netEncoder.vit_feat_dim).permute(0, 3, 1, 2)

        patch_key = model.netInstance.netEncoder._feats[0][:,:,1:]  # B, num_heads, num_patches, dim
        patch_key = patch_key.permute(0, 1, 3, 2).reshape(b, model.netInstance.netEncoder.vit_feat_dim, ph, pw)

        global_feat = out[:, 0]

    batch_features = global_feat

    batch_size = batch_features.shape[0]

    query = torch.nn.functional.normalize(batch_features.unsqueeze(1), dim=-1)  # [B, 1, d_k]
    key = torch.nn.functional.normalize(memory_bank_keys, dim=-1)  # [size, d_k]
    key = key.transpose(1, 0).unsqueeze(0).repeat(batch_size, 1, 1).to(query.device)  # [B, d_k, size]

    cos_dist = torch.bmm(query, key).squeeze(1)  # [B, size], larger the more similar
    rank_idx = torch.sort(cos_dist, dim=-1, descending=True)[1][:, :memory_bank_topk]  # [B, k]
    value = memory_bank.unsqueeze(0).repeat(batch_size, 1, 1).to(query.device)  # [B, size, d_v]

    out = torch.gather(value, dim=1, index=rank_idx[..., None].repeat(1, 1, memory_bank_dim))  # [B, k, d_v]

    weights = torch.gather(cos_dist, dim=-1, index=rank_idx)  # [B, k]
    weights = torch.nn.functional.normalize(weights, p=1.0, dim=-1).unsqueeze(-1).repeat(1, 1, memory_bank_dim)  # [B, k, d_v] weights have been normalized

    out = weights * out
    out = torch.sum(out, dim=1)

    batch_mean_out = torch.mean(out, dim=0)

    weight_aux = {
        'weights': weights[:, :, 0],  # [B, k], weights from large to small
        'pick_idx': rank_idx,  # [B, k]
    }

    batch_embedding = batch_mean_out
    embeddings = out
    weights = weight_aux

    bank_embedding_model_input = [batch_embedding, embeddings, weights]

    return bank_embedding_model_input


class FixedDirectionLight(torch.nn.Module):
    def __init__(self, direction, amb, diff):
        super(FixedDirectionLight, self).__init__()
        self.light_dir = direction
        self.amb = amb
        self.diff = diff
        self.is_hacking = not (isinstance(self.amb, float)
                               or isinstance(self.amb, int))

    def forward(self, feat):
        batch_size = feat.shape[0]
        if self.is_hacking:
            return torch.concat([self.light_dir, self.amb, self.diff], -1)
        else:
            return torch.concat([self.light_dir, torch.FloatTensor([self.amb, self.diff]).to(self.light_dir.device)], -1).expand(batch_size, -1)

    def shade(self, feat, kd, normal):
        light_params = self.forward(feat)
        light_dir = light_params[..., :3][:, None, None, :]
        int_amb = light_params[..., 3:4][:, None, None, :]
        int_diff = light_params[..., 4:5][:, None, None, :]
        shading = (int_amb + int_diff *
                   torch.clamp(util.dot(light_dir, normal), min=0.0))
        shaded = shading * kd
        return shaded, shading


def render_bones(mvp, bones_pred, size=(256, 256)):
    bone_world4 = torch.concat([bones_pred, torch.ones_like(bones_pred[..., :1]).to(bones_pred.device)], dim=-1)
    b, f, num_bones = bone_world4.shape[:3]
    bones_clip4 = (bone_world4.view(b, f, num_bones*2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4)
    bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4]  # b, f, num_bones, 2, 2
    dpi = 32
    fx, fy = size[1] // dpi, size[0] // dpi

    rendered = []
    for b_idx in range(b):
        for f_idx in range(f):
            frame_bones_uv = bones_uv[b_idx, f_idx].cpu().numpy()
            fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False)
            ax = plt.Axes(fig, [0., 0., 1., 1.])
            ax.set_axis_off()
            for bone in frame_bones_uv:
                ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20)
            ax.set_xlim(-1, 1)
            ax.set_ylim(-1, 1)
            ax.invert_yaxis()
            # Convert to image
            fig.add_axes(ax)
            fig.canvas.draw_idle()
            image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            w, h = fig.canvas.get_width_height()
            image.resize(h, w, 3)
            rendered += [image / 255.]
    return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2)).to(bones_pred.device)


def add_mesh_color(mesh, color):
    verts = mesh.verts_padded()
    color = torch.FloatTensor(color).to(verts.device).view(1,1,3) / 255
    mesh.textures = renderer.TexturesVertex(verts_features=verts*0+color)
    return mesh


def create_sphere(position, scale, device, color=[139, 149, 173]):
    mesh = utils.ico_sphere(2).to(device)
    mesh = mesh.extend(position.shape[0])

    # scale and offset
    mesh = mesh.update_padded(mesh.verts_padded() * scale + position[:, None])

    mesh = add_mesh_color(mesh, color)

    return mesh


def estimate_bone_rotation(b):
    """
    (0, 0, 1) = matmul(R^(-1), b)

    assumes x, y is a symmetry plane

    returns R
    """
    b = b / torch.norm(b, dim=-1, keepdim=True)

    n = torch.FloatTensor([[1, 0, 0]]).to(b.device)
    n = n.expand_as(b)
    v = torch.cross(b, n, dim=-1)

    R = torch.stack([n, v, b], dim=-1).transpose(-2, -1)

    return R


def estimate_vector_rotation(vector_a, vector_b):
    """
    vector_a = matmul(R, vector_b)

    returns R

    https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d
    """
    vector_a = vector_a / torch.norm(vector_a, dim=-1, keepdim=True)
    vector_b = vector_b / torch.norm(vector_b, dim=-1, keepdim=True)

    v = torch.cross(vector_a, vector_b, dim=-1)
    c = torch.sum(vector_a * vector_b, dim=-1)

    skew = torch.stack([
        torch.stack([torch.zeros_like(v[..., 0]), -v[..., 2], v[..., 1]], dim=-1),
        torch.stack([v[..., 2], torch.zeros_like(v[..., 0]), -v[..., 0]], dim=-1),
        torch.stack([-v[..., 1], v[..., 0], torch.zeros_like(v[..., 0])], dim=-1)],
        dim=-1)

    R = torch.eye(3, device=vector_a.device)[None] + skew + torch.matmul(skew, skew) / (1 + c[..., None, None])

    return R


def create_elipsoid(bone, scale=0.05, color=[139, 149, 173], generic_rotation_estim=True):
    length = torch.norm(bone[:, 0] - bone[:, 1], dim=-1)

    mesh = utils.ico_sphere(2).to(bone.device)
    mesh = mesh.extend(bone.shape[0])
    # scale x, y
    verts = mesh.verts_padded() * torch.FloatTensor([scale, scale, 1]).to(bone.device)
    # stretch along z axis, set the start to origin
    verts[:, :, 2] = verts[:, :, 2] * length[:, None] * 0.5 + length[:, None] * 0.5

    bone_vector = bone[:, 1] - bone[:, 0]
    z_vector = torch.FloatTensor([[0, 0, 1]]).to(bone.device)
    z_vector = z_vector.expand_as(bone_vector)
    if generic_rotation_estim:
        rot = estimate_vector_rotation(z_vector, bone_vector)
    else:
        rot = estimate_bone_rotation(bone_vector)
    tsf = transforms.Rotate(rot, device=bone.device)
    tsf = tsf.compose(transforms.Translate(bone[:, 0], device=bone.device))
    verts = tsf.transform_points(verts)

    mesh = mesh.update_padded(verts)

    mesh = add_mesh_color(mesh, color)

    return mesh


def convert_textures_vertex_to_textures_uv(meshes: structures.Meshes, color1, color2) -> renderer.TexturesUV:
    """
    Convert a TexturesVertex object to a TexturesUV object.
    """
    color1 = torch.Tensor(color1).to(meshes.device).view(1, 1, 3) / 255
    color2 = torch.Tensor(color2).to(meshes.device).view(1, 1, 3) / 255
    textures_vertex = meshes.textures
    assert isinstance(textures_vertex, renderer.TexturesVertex), "Input meshes must have TexturesVertex"
    verts_rgb = textures_vertex.verts_features_padded()
    faces_uvs = meshes.faces_padded()
    batch_size = verts_rgb.shape[0]
    maps = torch.zeros(batch_size, 128, 128, 3, device=verts_rgb.device)
    maps[:, :, :64, :] = color1
    maps[:, :, 64:, :] = color2
    is_first = (verts_rgb == color1)[..., 0]
    verts_uvs = torch.zeros(batch_size, verts_rgb.shape[1], 2, device=verts_rgb.device)
    verts_uvs[is_first] = torch.FloatTensor([0.25, 0.5]).to(verts_rgb.device)
    verts_uvs[~is_first] = torch.FloatTensor([0.75, 0.5]).to(verts_rgb.device)
    textures_uv = renderer.TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
    meshes.textures = textures_uv
    return meshes


def create_bones_scene(bones, joint_color=[66, 91, 140], bone_color=[119, 144, 189], show_end_point=False):
    meshes = []
    for bone_i in range(bones.shape[1]):
        # points
        meshes += [create_sphere(bones[:, bone_i, 0], 0.1, bones.device, color=joint_color)]
        if show_end_point:
            meshes += [create_sphere(bones[:, bone_i, 1], 0.1, bones.device, color=joint_color)]

        # connecting ellipsoid
        meshes += [create_elipsoid(bones[:, bone_i], color=bone_color)]

    current_batch_size = bones.shape[0]
    meshes = [structures.join_meshes_as_scene([m[i] for m in meshes]) for i in range(current_batch_size)]
    mesh = structures.join_meshes_as_batch(meshes)

    return mesh


def run_pipeline(model_items, cfgs, input_img, device):
    epoch = 999
    total_iter = 999999
    model = model_items[0]
    memory_bank = model_items[1]
    memory_bank_keys = model_items[2]

    input_image = torch.stack([torchvision.transforms.ToTensor()(input_img)], dim=0).to(device)

    with torch.no_grad():
        model.netPrior.eval()
        model.netInstance.eval()
        input_image = torch.nn.functional.interpolate(input_image, size=(256, 256), mode='bilinear', align_corners=False)
        input_image = input_image[:, None, :, :]  # [B=1, F=1, 3, 256, 256]

        bank_embedding = get_bank_embedding(
            input_image,
            memory_bank_keys,
            memory_bank,
            model,
            memory_bank_topk=cfgs.get("memory_bank_topk", 10),
            memory_bank_dim=128
        )

        prior_shape, dino_pred, classes_vectors = model.netPrior(
            category_name='tmp',
            perturb_sdf=False,
            total_iter=total_iter,
            is_training=False,
            class_embedding=bank_embedding
        )
        Instance_out = model.netInstance(
            'tmp',
            input_image,
            prior_shape,
            epoch,
            dino_features=None,
            dino_clusters=None,
            total_iter=total_iter,
            is_training=False
        )  # frame dim collapsed N=(B*F)
        if len(Instance_out) == 13:
            shape, pose_raw, pose, mvp, w2c, campos, texture_pred, im_features, dino_feat_im_calc, deform, all_arti_params, light, forward_aux = Instance_out
            im_features_map = None
        else:
            shape, pose_raw, pose, mvp, w2c, campos, texture_pred, im_features, dino_feat_im_calc, deform, all_arti_params, light, forward_aux, im_features_map = Instance_out

        class_vector = classes_vectors  # the bank embeddings

        gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(device), amb=0.2, diff=0.7)

        image_pred, mask_pred, _, _, _, shading = model.render(
            shape, texture_pred, mvp, w2c, campos, 256, background=model.background_mode,
            im_features=im_features, light=gray_light, prior_shape=prior_shape, render_mode='diffuse',
            render_flow=False, dino_pred=None, im_features_map=im_features_map
        )
        mask_pred = mask_pred.expand_as(image_pred)
        shading = shading.expand_as(image_pred)
        # render bones in pytorch3D style
        posed_bones = forward_aux["posed_bones"].squeeze(1)
        jc, bc = [66, 91, 140], [119, 144, 189]
        bones_meshes = create_bones_scene(posed_bones, joint_color=jc, bone_color=bc, show_end_point=True)
        bones_meshes = convert_textures_vertex_to_textures_uv(bones_meshes, color1=jc, color2=bc)
        nv_meshes = make_mesh(verts=bones_meshes.verts_padded(), faces=bones_meshes.faces_padded()[0:1],
                              uvs=bones_meshes.textures.verts_uvs_padded(), uv_idx=bones_meshes.textures.faces_uvs_padded()[0:1],
                              material=material_texture.Texture2D(bones_meshes.textures.maps_padded()))
        buffers = render_mesh(dr.RasterizeGLContext(), nv_meshes, mvp, w2c, campos, nv_meshes.material, lgt=gray_light, feat=im_features, dino_pred=None, resolution=256, bsdf="diffuse")

        shaded = buffers["shaded"].permute(0, 3, 1, 2)
        bone_image = shaded[:, :3, :, :]
        bone_mask = shaded[:, 3:, :, :]
        mask_final = mask_pred.logical_or(bone_mask)
        mask_final = mask_final.int()
        image_with_bones = bone_image * bone_mask * 0.5 + (shading * (1 - bone_mask * 0.5) + 0.5 * (mask_final.float() - mask_pred.float()))

        mesh_image = save_images(shading, mask_pred)
        mesh_bones_image = save_images(image_with_bones, mask_final)

        final_shape = shape.clone()
        prior_shape = prior_shape.clone()

        final_mesh_tri = trimesh.Trimesh(
            vertices=final_shape.v_pos[0].detach().cpu().numpy(),
            faces=final_shape.t_pos_idx[0].detach().cpu().numpy(),
            process=False,
            maintain_order=True)
        prior_mesh_tri = trimesh.Trimesh(
            vertices=prior_shape.v_pos[0].detach().cpu().numpy(),
            faces=prior_shape.t_pos_idx[0].detach().cpu().numpy(),
            process=False,
            maintain_order=True)


def run_demo():
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0', type=str,
                        help='Specify a GPU device')
    parser.add_argument('--num_workers', default=4, type=int,
                        help='Specify the number of worker threads for data loaders')
    parser.add_argument('--seed', default=0, type=int,
                        help='Specify a random seed')
    parser.add_argument('--config', default='./ckpts/configs.yml',
                        type=str)  # Model config path
    parser.add_argument('--checkpoint_path', default='./ckpts/iter0800000.pth', type=str)

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8088'
    dist.init_process_group("gloo", rank=_GPU_ID, world_size=1)
    torch.cuda.set_device(_GPU_ID)
    args.rank = _GPU_ID
    args.world_size = 1
    args.gpu = os.environ['CUDA_VISIBLE_DEVICES']
    device = f'cuda:{_GPU_ID}'

    resolution = (256, 256)
    batch_size = 1
    model_cfgs = setup_runtime(args)
    bone_y_thresh = 0.4
    body_bone_idx_preset = [3, 6, 6, 3]
    model_cfgs['body_bone_idx_preset'] = body_bone_idx_preset

    model = Unsup3DDDP(model_cfgs)
    # a hack attempt
    model.netPrior.classes_vectors = torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(123, 128), a=-0.05, b=0.05))
    cp = torch.load(args.checkpoint_path, map_location=device)
    model.load_model_state(cp)
    memory_bank_keys = cp['memory_bank_keys']
    memory_bank = cp['memory_bank']

    model.to(device)
    memory_bank.to(device)
    memory_bank_keys.to(device)
    model_items = [
        model,
        memory_bank,
        memory_bank_keys
    ]

    predictor = sam_init()

    custom_theme = gr.themes.Soft(primary_hue="blue").set(
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200")
    custom_css = '''#disp_image {
        text-align: center; /* Horizontally center the content */
    }'''

    with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
                gr.Markdown(_DESCRIPTION)
        with gr.Row(variant='panel'):
            with gr.Column(scale=1):
                input_image = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None)

                example_folder = os.path.join(os.path.dirname(__file__), "./example_images")
                example_fns = [os.path.join(example_folder, example) for example in os.listdir(example_folder)]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    # outputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=30
                )
            with gr.Column(scale=1):
                processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=256, tool=None, image_mode='RGB', elem_id="disp_image")
                processed_image_highres = gr.Image(type='pil', image_mode='RGB', visible=False, tool=None)

                with gr.Accordion('Advanced options', open=True):
                    with gr.Row():
                        with gr.Column():
                            input_processing = gr.CheckboxGroup(['Use SAM to center animal'],
                                                                label='Input Image Preprocessing',
                                                                value=['Use SAM to center animal'],
                                                                info='untick this, if animal is already centered, e.g. in example images')
                        # with gr.Column():
                        #     output_processing = gr.CheckboxGroup(['Background Removal'], label='Output Image Postprocessing', value=[])
                    # with gr.Row():
                    #     with gr.Column():
                    #         scale_slider = gr.Slider(1, 5, value=3, step=1,
                    #                                  label='Classifier Free Guidance Scale')
                    #     with gr.Column():
                    #         steps_slider = gr.Slider(15, 100, value=50, step=1,
                    #                                  label='Number of Diffusion Inference Steps')
                    # with gr.Row():
                    #     with gr.Column():
                    #         seed = gr.Number(42, label='Seed')
                    #     with gr.Column():
                    #         crop_size = gr.Number(192, label='Crop size')
                    # crop_size = 192
                run_btn = gr.Button('Generate', variant='primary', interactive=True)
        with gr.Row():
            view_1 = gr.Image(interactive=False, height=256, show_label=False)
            view_2 = gr.Image(interactive=False, height=256, show_label=False)
        with gr.Row():
            shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Reconstructed Model")
            shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Bank Base Shape Model")

        with gr.Row():
            view_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
            normal_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)

        run_btn.click(fn=partial(preprocess, predictor),
                      inputs=[input_image, input_processing],
                      outputs=[processed_image_highres, processed_image], queue=True
                      ).success(fn=partial(run_pipeline, model_items, model_cfgs),
                                inputs=[processed_image, device],
                                outputs=[view_1, view_2, shape_1, shape_2]
                                )
    demo.queue().launch(share=True, max_threads=80)


if __name__ == '__main__':
    fire.Fire(run_demo)
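
Note on wiring: as committed, run_pipeline builds mesh_image, mesh_bones_image, final_mesh_tri, and prior_mesh_tri but has no return statement, while the run_btn.click(...).success(...) chain above expects four outputs for [view_1, view_2, shape_1, shape_2] (and lists the plain device string among its Gradio inputs). A minimal sketch of one way the outputs could be handed back, assuming gr.Model3D is given a path to an exported .obj file; the helper name and temporary-file layout are illustrative and not part of the commit:

import os
import tempfile

def export_pipeline_outputs(mesh_image, mesh_bones_image, final_mesh_tri, prior_mesh_tri):
    # Hypothetical helper (not in the commit): write both trimesh meshes to disk so
    # gr.Model3D can load them, and pass the two PIL images through unchanged.
    out_dir = tempfile.mkdtemp(prefix="fauna_demo_")
    final_path = os.path.join(out_dir, "reconstructed.obj")
    prior_path = os.path.join(out_dir, "bank_base_shape.obj")
    final_mesh_tri.export(final_path)   # posed reconstruction
    prior_mesh_tri.export(prior_path)   # memory-bank base (prior) shape
    # Order matches outputs=[view_1, view_2, shape_1, shape_2] in run_demo()
    return mesh_image, mesh_bones_image, final_path, prior_path
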
requirements.txt
ADDED
@@ -0,0 +1,32 @@
clip==1.0
ConfigArgParse==1.5.3
core==1.0.1
diffusers==0.20.0
einops==0.4.1
faiss==1.7.3
fire==0.5.0
glfw==2.5.7
gradio==4.12.0
imageio==2.27.0
ipdb==0.13.9
lpips==0.1.4
matplotlib==3.8.1
numpy==1.23.1
nvdiffrast==0.3.0
Pillow==9.2.0
Pillow==10.1.0
PyOpenGL==3.1.6
PyOpenGL==3.1.7
pytorch3d==0.7.2
PyYAML==6.0
PyYAML==6.0.1
scipy==1.9.1
segment_anything==1.0
siren_pytorch==0.1.7
tinycudann==1.7
torch==1.10.0
torchvision==0.11.0
transformers==4.28.1
trimesh==4.0.0
wandb==0.14.2
xatlas==0.0.7
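
The checkpoints and example assets the demo reads are not included in this commit: sam_init() expects sam_pt/sam_vit_h_4b8939.pth, run_demo() defaults to ./ckpts/configs.yml and ./ckpts/iter0800000.pth, and gr.Examples scans an example_images/ folder. A small, purely illustrative pre-flight check under those assumptions:

import os

# Paths referenced by app.py; the files themselves must be added separately.
EXPECTED_ASSETS = [
    "sam_pt/sam_vit_h_4b8939.pth",  # SAM ViT-H weights loaded in sam_init()
    "ckpts/configs.yml",            # --config default in run_demo()
    "ckpts/iter0800000.pth",        # --checkpoint_path default in run_demo()
    "example_images",               # folder listed for gr.Examples
]

missing = [p for p in EXPECTED_ASSETS if not os.path.exists(p)]
if missing:
    raise FileNotFoundError("Missing demo assets: " + ", ".join(missing))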