Spaces:
Sleeping
Sleeping
import os | |
import fire | |
import gradio as gr | |
from PIL import Image | |
from functools import partial | |
import argparse | |
# Runtime installs of two extensions (tiny-cuda-nn torch bindings and pytorch3d) from git.
# They run at import time, in order, via the shell; os.system's exit status is not checked,
# so a failed install only surfaces later as an ImportError.
for _pip_cmd in (
    'pip install --global-option="--no-networks" git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch',
    'pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"',
):
    os.system(_pip_cmd)
del _pip_cmd
import cv2 | |
import time | |
import numpy as np | |
import trimesh | |
from segment_anything import sam_model_registry, SamPredictor | |
import random | |
from pytorch3d import transforms | |
import torch | |
import torchvision | |
import torch.distributed as dist | |
import nvdiffrast.torch as dr | |
from video3d.model_ddp import Unsup3DDDP, forward_to_matrix | |
from video3d.trainer_few_shot import Fewshot_Trainer | |
from video3d.trainer_ddp import TrainerDDP | |
from video3d import setup_runtime | |
from video3d.render.mesh import make_mesh | |
from video3d.utils.skinning_v4 import estimate_bones, skinning, euler_angles_to_matrix | |
from video3d.utils.misc import save_obj | |
from video3d.render import util | |
import matplotlib.pyplot as plt | |
from pytorch3d import utils, renderer, transforms, structures, io | |
from video3d.render.render import render_mesh | |
from video3d.render.material import texture as material_texture | |
# Page title and HTML description rendered at the top of the Gradio demo.
_TITLE = '''Learning the 3D Fauna of the Web'''
_DESCRIPTION = '''
<div>
Reconstruct any quadruped animal from one image.
</div>
<div>
The demo only contains the 3D reconstruction part.
</div>
'''
# CUDA device index used by this demo (SAM, the reconstruction model, torch.cuda.set_device).
_GPU_ID = 0

# Older Pillow releases have no ``Image.Resampling`` namespace; alias it to the module so
# ``Image.Resampling.LANCZOS`` resolves to the legacy module-level ``Image.LANCZOS``.
try:
    Image.Resampling
except AttributeError:
    Image.Resampling = Image
def sam_init():
    """Load the SAM ViT-H checkpoint and wrap it in a ``SamPredictor``.

    The weights are read from ``sam_pt/sam_vit_h_4b8939.pth`` next to this file and the
    model is moved to ``cuda:{_GPU_ID}``.
    """
    ckpt_path = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_h_4b8939.pth")
    build_sam = sam_model_registry["vit_h"]
    sam_model = build_sam(checkpoint=ckpt_path).to(device=f"cuda:{_GPU_ID}")
    return SamPredictor(sam_model)
def sam_segment(predictor, input_image, *bbox_coords):
    """Segment the object inside a box prompt with SAM and black out everything else.

    Args:
        predictor: a ``SamPredictor`` (see ``sam_init``).
        input_image: RGB image (PIL image or array) convertible to an ``(H, W, 3)`` uint8 array.
        *bbox_coords: box prompt ``x_min, y_min, x_max, y_max`` in pixel coordinates.

    Returns:
        A PIL ``RGB`` image of the same size in which pixels outside the selected SAM mask
        are set to black.
    """
    bbox = np.array(bbox_coords)
    image = np.asarray(input_image)
    start_time = time.time()
    predictor.set_image(image)
    masks_bbox, _scores_bbox, _logits_bbox = predictor.predict(
        box=bbox,
        multimask_output=True
    )
    print(f"SAM Time: {time.time() - start_time:.3f}s")
    # multimask_output=True yields several candidate masks; the last one is used.
    mask = masks_bbox[-1].astype(bool)
    # Build a genuine 3-channel result: an (H, W, 4) buffer handed to Image.fromarray as
    # 'RGB' would have its bytes reinterpreted and the picture scrambled. The background
    # is black to match the black padding applied afterwards in preprocess().
    out_image = np.where(mask[:, :, None], image[:, :, :3], 0).astype(np.uint8)
    torch.cuda.empty_cache()
    return Image.fromarray(out_image)
def expand2square(pil_img, background_color):
    """Pad ``pil_img`` to a square canvas filled with ``background_color``.

    The original image is centred along its shorter side. An already-square image is
    returned as the same object, without copying.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    offset = ((side - width) // 2, (side - height) // 2)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
def preprocess(predictor, input_image, chk_group=None, segment=True):
    """Downscale, optionally segment with SAM, and square-pad the uploaded image.

    Args:
        predictor: ``SamPredictor`` passed through to ``sam_segment``.
        input_image: PIL image from the Gradio upload (may be RGBA). It is downscaled in
            place to fit within 1024x1024.
        chk_group: values of the preprocessing CheckboxGroup. When given, it overrides
            ``segment``: SAM runs iff "Use SAM to center animal" is ticked.
        segment: whether to run SAM when ``chk_group`` is None.

    Returns:
        ``(square_image, square_image_256)``: the black-padded square image and a
        256x256 LANCZOS-resized copy of it.

    Raises:
        gr.Error: if no image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an image first.")
    RES = 1024
    input_image.thumbnail([RES, RES], Image.Resampling.LANCZOS)
    if chk_group is not None:
        segment = "Use SAM to center animal" in chk_group
    if segment:
        # SAM box prompt = bounding box of the non-transparent pixels (alpha channel);
        # for a fully opaque upload this is the whole image.
        alpha = np.asarray(input_image.convert('RGBA'))[:, :, -1]
        ys, xs = np.nonzero(alpha)
        if xs.size == 0:
            # Fully transparent input: use the whole image instead of failing on an empty min().
            x_min, y_min = 0, 0
            x_max, y_max = input_image.width - 1, input_image.height - 1
        else:
            x_min, x_max = int(xs.min()), int(xs.max())
            y_min, y_max = int(ys.min()), int(ys.max())
        input_image = sam_segment(predictor, input_image.convert('RGB'), x_min, y_min, x_max, y_max)
    input_image = expand2square(input_image, (0, 0, 0))
    return input_image, input_image.resize((256, 256), Image.Resampling.LANCZOS)
def save_images(images, mask_pred, mode="transparent"):
    """Convert the first image of a batch to a PIL image, optionally applying a mask.

    Args:
        images: batched tensor; ``images[0]`` is a ``(C, H, W)`` image with values
            clamped to ``[0, 1]``.
        mask_pred: batched mask tensor, or ``None`` for no masking. ``mask_pred[0]`` must
            broadcast against ``images[0]``.
        mode: ``"white"`` composites onto white, ``"black"`` onto black. Any other value
            appends the first mask channel as an alpha channel.

    Returns:
        ``PIL.Image`` built from ``uint8(img * 255)`` (values truncated, not rounded).
    """
    img = images[0].detach().clamp(0, 1)
    mask = None if mask_pred is None else mask_pred[0]
    if mask is not None:
        mask = mask.detach().clamp(0, 1)
        if mode == "white":
            img = img * mask + 1 * (1 - mask)
        elif mode == "black":
            img = img * mask + 0 * (1 - mask)
        else:
            # Cast so integer masks (e.g. from .int()) concatenate with the float image.
            img = torch.cat([img, mask[0:1].to(img.dtype)], 0)
    img = img.permute(1, 2, 0).cpu().numpy()
    return Image.fromarray(np.uint8(img * 255))
def get_bank_embedding(rgb, memory_bank_keys, memory_bank, model, memory_bank_topk=10, memory_bank_dim=128):
    """Retrieve a memory-bank embedding for each input frame from its ViT class token.

    Each frame is encoded with ``model.netInstance.netEncoder.ViT``. The L2-normalized
    class token is compared by cosine similarity against the L2-normalized
    ``memory_bank_keys``, and the ``memory_bank_topk`` best-matching rows of
    ``memory_bank`` are blended with L1-normalized similarity weights.

    Args:
        rgb: tensor indexed as ``(batch, frames, C, H, W)`` with values in ``[0, 1]``.
        memory_bank_keys: ``(size, d_k)`` key matrix.
        memory_bank: ``(size, d_v)`` value matrix; only the first ``memory_bank_dim``
            columns are used.
        model: object exposing ``netInstance.netEncoder.ViT`` with ``prepare_tokens``,
            ``blocks`` and ``norm``.
        memory_bank_topk: number of nearest keys to blend (clamped to the bank size).
        memory_bank_dim: number of value channels kept.

    Returns:
        ``[batch_embedding, embeddings, weight_aux]``:
        - the mean embedding over all frames, ``(memory_bank_dim,)``;
        - per-frame embeddings, ``(batch*frames, memory_bank_dim)``;
        - a dict with ``'weights'`` ``(batch*frames, k)`` in descending order and
          ``'pick_idx'`` ``(batch*frames, k)`` bank row indices.
    """
    vit = model.netInstance.netEncoder.ViT
    batch_size, num_frames = rgb.shape[:2]
    images = rgb.reshape(batch_size * num_frames, *rgb.shape[2:])  # 0~1
    x = images * 2 - 1  # rescale to (-1, 1) for DINO
    with torch.no_grad():
        x = vit.prepare_tokens(x)
        for blk in vit.blocks:
            x = blk(x)
        out = vit.norm(x)
    global_feat = out[:, 0]  # first token is the class token

    query = torch.nn.functional.normalize(global_feat, dim=-1)  # [N, d_k]
    key = torch.nn.functional.normalize(memory_bank_keys, dim=-1).to(query.device)  # [size, d_k]
    cos_dist = query @ key.transpose(0, 1)  # [N, size], larger = more similar
    k = min(memory_bank_topk, cos_dist.shape[-1])
    top_sim, rank_idx = torch.topk(cos_dist, k, dim=-1)  # [N, k], sorted descending
    values = memory_bank.to(query.device)[rank_idx][..., :memory_bank_dim]  # [N, k, d_v]
    weights = torch.nn.functional.normalize(top_sim, p=1.0, dim=-1)  # [N, k], L1-normalized
    embeddings = (weights.unsqueeze(-1) * values).sum(dim=1)  # [N, d_v]
    batch_embedding = embeddings.mean(dim=0)
    weight_aux = {
        'weights': weights,  # [N, k], weights from large to small
        'pick_idx': rank_idx,  # [N, k]
    }
    return [batch_embedding, embeddings, weight_aux]
class FixedDirectionLight(torch.nn.Module):
    """Light with a fixed direction and fixed ambient/diffuse intensities.

    ``forward`` returns per-sample light parameters laid out as
    ``[dir_x, dir_y, dir_z, ambient, diffuse]`` along the last dimension.
    """

    def __init__(self, direction, amb, diff):
        super().__init__()
        self.light_dir = direction
        self.amb = amb
        self.diff = diff
        # Scalar intensities are broadcast over the batch in forward(); anything else
        # (e.g. per-sample tensors) is concatenated with the direction as-is.
        self.is_hacking = not isinstance(self.amb, (float, int))

    def forward(self, feat):
        batch_size = feat.shape[0]
        if self.is_hacking:
            return torch.cat([self.light_dir, self.amb, self.diff], dim=-1)
        intensities = torch.FloatTensor([self.amb, self.diff]).to(self.light_dir.device)
        params = torch.cat([self.light_dir, intensities], dim=-1)
        return params.expand(batch_size, -1)

    def shade(self, feat, kd, normal):
        """Return ``(shading * kd, shading)`` with ``shading = amb + diff * max(dot(dir, normal), 0)``."""
        params = self.forward(feat)
        # Insert two singleton dims so the per-sample parameters broadcast over the
        # spatial dims of ``normal`` (assumed (N, H, W, 3) -- TODO confirm with callers).
        direction = params[..., 0:3][:, None, None, :]
        ambient = params[..., 3:4][:, None, None, :]
        diffuse = params[..., 4:5][:, None, None, :]
        lambert = torch.clamp(util.dot(direction, normal), min=0.0)
        shading = ambient + diffuse * lambert
        return shading * kd, shading
def render_bones(mvp, bones_pred, size=(256, 256)):
    """Project bones with ``mvp`` and draw them as 2D line plots using matplotlib.

    Args:
        mvp: model-view-projection matrices, reshaped to ``(b, f, 1, 4, 4)``; the points
            are multiplied as row vectors by ``mvp`` transposed.
        bones_pred: bone endpoints indexed as ``(b, f, num_bones, 2, 3)``.
        size: target ``(height, width)`` in pixels; the canvas is
            ``(size // 32) * 32`` pixels per side because the figure uses dpi=32.

    Returns:
        Float64 tensor ``(b*f, 3, H, W)`` with values in ``[0, 1]``, on ``bones_pred.device``.
    """
    bone_world4 = torch.cat([bones_pred, torch.ones_like(bones_pred[..., :1])], dim=-1)
    b, f, num_bones = bone_world4.shape[:3]
    bones_clip4 = (bone_world4.view(b, f, num_bones * 2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4)
    bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4]  # b, f, num_bones, 2, 2
    dpi = 32
    fx, fy = size[1] // dpi, size[0] // dpi
    rendered = []
    for b_idx in range(b):
        for f_idx in range(f):
            frame_bones_uv = bones_uv[b_idx, f_idx].detach().cpu().numpy()
            fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False)
            try:
                ax = plt.Axes(fig, [0., 0., 1., 1.])
                ax.set_axis_off()
                for bone in frame_bones_uv:
                    ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20)
                ax.set_xlim(-1, 1)
                ax.set_ylim(-1, 1)
                ax.invert_yaxis()
                fig.add_axes(ax)
                # draw() renders synchronously (draw_idle may defer), so the buffer is current.
                fig.canvas.draw()
                # buffer_rgba() is (H, W, 4); copy the RGB part because the buffer belongs to
                # the canvas (an np.frombuffer view cannot be resized in place).
                image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
            finally:
                # Close every figure; pyplot otherwise keeps them alive for the whole process.
                plt.close(fig)
            rendered.append(image / 255.)
    return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2)).to(bones_pred.device)
def add_mesh_color(mesh, color):
    """Give every vertex of ``mesh`` the single RGB ``color`` (0-255 scale).

    The colour is stored as a ``TexturesVertex``; ``mesh.textures`` is replaced in place
    and the same mesh object is returned.
    """
    verts = mesh.verts_padded()
    rgb = torch.FloatTensor(color).to(verts.device).view(1, 1, 3) / 255
    per_vertex_rgb = verts * 0 + rgb  # broadcast to the padded vertex layout
    mesh.textures = renderer.TexturesVertex(verts_features=per_vertex_rgb)
    return mesh
def create_sphere(position, scale, device, color=[139, 149, 173]):
    """Create one level-2 icosphere per row of ``position``, scaled by ``scale``.

    ``position`` presumably has shape (N, 3) (it is broadcast as ``position[:, None]``
    against the padded vertices). The spheres are coloured with ``add_mesh_color``.
    """
    template = utils.ico_sphere(2).to(device)
    spheres = template.extend(position.shape[0])
    moved_verts = spheres.verts_padded() * scale + position[:, None]
    return add_mesh_color(spheres.update_padded(moved_verts), color)
def estimate_bone_rotation(b):
    """Build a rotation whose rows are ``(x_axis, b_hat x x_axis, b_hat)``.

    ``b_hat`` is ``b`` normalized along the last dimension. The last row is always
    ``b_hat``, so ``(0, 0, 1) @ R == b_hat``. The matrix is orthonormal, and
    ``R @ b_hat == (0, 0, 1)``, only when ``b`` has zero x-component; this is the
    "x = 0 symmetry plane" assumption.
    """
    b = b / b.norm(dim=-1, keepdim=True)
    x_axis = torch.FloatTensor([[1, 0, 0]]).to(b.device).expand_as(b)
    side = torch.cross(b, x_axis, dim=-1)
    return torch.stack([x_axis, side, b], dim=-2)
def estimate_vector_rotation(vector_a, vector_b):
    """Return the rotation ``R`` with ``vector_a = R @ vector_b`` (both normalized first).

    Rodrigues-style construction, see
    https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d
    It is undefined (division by zero) when the vectors are exactly antiparallel.
    """
    a_hat = vector_a / vector_a.norm(dim=-1, keepdim=True)
    b_hat = vector_b / vector_b.norm(dim=-1, keepdim=True)
    axis = torch.cross(a_hat, b_hat, dim=-1)
    cos_angle = (a_hat * b_hat).sum(dim=-1)
    ax, ay, az = axis.unbind(-1)
    zero = torch.zeros_like(ax)
    # Transposed cross-product matrix of ``axis``: this yields the rotation taking b_hat onto a_hat.
    skew = torch.stack([
        torch.stack([zero, az, -ay], dim=-1),
        torch.stack([-az, zero, ax], dim=-1),
        torch.stack([ay, -ax, zero], dim=-1),
    ], dim=-2)
    identity = torch.eye(3, device=vector_a.device)[None]
    return identity + skew + torch.matmul(skew, skew) / (1 + cos_angle[..., None, None])
def create_elipsoid(bone, scale=0.05, color=[139, 149, 173], generic_rotation_estim=True):
    """Create one elongated icosphere ("bone") per segment in ``bone``.

    ``bone[:, 0]`` is the segment start and ``bone[:, 1]`` its end. The level-2 icosphere
    is processed in this order:
    1. shrunk by ``scale`` in x/y;
    2. stretched along z so that it spans ``[0, length]``;
    3. rotated from +z onto the segment direction;
    4. translated to the segment start.
    The result is coloured with ``add_mesh_color``.
    """
    device = bone.device
    start, end = bone[:, 0], bone[:, 1]
    length = torch.norm(start - end, dim=-1)
    spheres = utils.ico_sphere(2).to(device).extend(bone.shape[0])
    verts = spheres.verts_padded() * torch.FloatTensor([scale, scale, 1]).to(device)
    # Map z from the unit sphere's [-1, 1] to [0, length] so the bone starts at the origin.
    verts[:, :, 2] = verts[:, :, 2] * length[:, None] * 0.5 + length[:, None] * 0.5
    direction = end - start
    z_axis = torch.FloatTensor([[0, 0, 1]]).to(device).expand_as(direction)
    if generic_rotation_estim:
        rot = estimate_vector_rotation(z_axis, direction)
    else:
        rot = estimate_bone_rotation(direction)
    tsf = transforms.Rotate(rot, device=device).compose(transforms.Translate(start, device=device))
    spheres = spheres.update_padded(tsf.transform_points(verts))
    return add_mesh_color(spheres, color)
def convert_textures_vertex_to_textures_uv(meshes: structures.Meshes, color1, color2) -> renderer.TexturesUV:
    """Replace a two-colour ``TexturesVertex`` on ``meshes`` with an equivalent ``TexturesUV``.

    A 128x128 map is built with the left half (columns < 64) in ``color1`` and the right
    half in ``color2`` (both RGB on a 0-255 scale). A vertex whose red channel exactly
    equals ``color1``'s red maps to uv ``(0.25, 0.5)``; every other vertex maps to
    ``(0.75, 0.5)``. The faces' UV indices reuse the vertex indices.

    Note: despite the annotation, the input ``meshes`` object is returned, with its
    ``textures`` replaced in place.
    """
    device = meshes.device
    first_rgb = torch.Tensor(color1).to(device).view(1, 1, 3) / 255
    second_rgb = torch.Tensor(color2).to(device).view(1, 1, 3) / 255
    vertex_tex = meshes.textures
    assert isinstance(vertex_tex, renderer.TexturesVertex), "Input meshes must have TexturesVertex"
    verts_rgb = vertex_tex.verts_features_padded()
    face_uv_idx = meshes.faces_padded()
    n_meshes, n_verts = verts_rgb.shape[0], verts_rgb.shape[1]
    tex_device = verts_rgb.device
    uv_map = torch.zeros(n_meshes, 128, 128, 3, device=tex_device)
    uv_map[:, :, :64, :] = first_rgb
    uv_map[:, :, 64:, :] = second_rgb
    uses_first = (verts_rgb == first_rgb)[..., 0]
    vert_uv = torch.zeros(n_meshes, n_verts, 2, device=tex_device)
    vert_uv[uses_first] = torch.FloatTensor([0.25, 0.5]).to(tex_device)
    vert_uv[~uses_first] = torch.FloatTensor([0.75, 0.5]).to(tex_device)
    meshes.textures = renderer.TexturesUV(maps=uv_map, faces_uvs=face_uv_idx, verts_uvs=vert_uv)
    return meshes
def create_bones_scene(bones, joint_color=[66, 91, 140], bone_color=[119, 144, 189], show_end_point=False):
    """Assemble a skeleton mesh (joint spheres plus bone ellipsoids) for each batch element.

    ``bones`` is indexed as ``bones[batch, bone_index, endpoint]``, with endpoint 0 the
    start and 1 the end; presumably the shape is (B, num_bones, 2, 3). For every bone this
    adds a sphere at its start, optionally one at its end, and a connecting ellipsoid.
    All parts of one batch element are joined into a single scene, and the scenes are
    returned as a batched mesh.
    """
    parts = []
    for i in range(bones.shape[1]):
        bone = bones[:, i]
        parts.append(create_sphere(bone[:, 0], 0.1, bones.device, color=joint_color))
        if show_end_point:
            parts.append(create_sphere(bone[:, 1], 0.1, bones.device, color=joint_color))
        parts.append(create_elipsoid(bone, color=bone_color))
    scenes = [
        structures.join_meshes_as_scene([part[k] for part in parts])
        for k in range(bones.shape[0])
    ]
    return structures.join_meshes_as_batch(scenes)
def _export_first_mesh(mesh, path):
    """Write the first mesh of a batched mesh to ``path`` with trimesh and return ``path``.

    Reads ``mesh.v_pos[0]`` (vertices) and ``mesh.t_pos_idx[0]`` (triangles).
    ``process=False`` and ``maintain_order=True`` stop trimesh from merging or reordering
    vertices. Trimesh infers the file format from the extension of ``path``.
    """
    tri = trimesh.Trimesh(
        vertices=mesh.v_pos[0].detach().cpu().numpy(),
        faces=mesh.t_pos_idx[0].detach().cpu().numpy(),
        process=False,
        maintain_order=True)
    tri.export(path)
    return path


def run_pipeline(model_items, cfgs, input_img, device):
    """Reconstruct the animal in ``input_img`` and produce the demo outputs.

    Args:
        model_items: ``[model, memory_bank, memory_bank_keys]``.
        cfgs: model config mapping; only ``memory_bank_topk`` (default 10) is read here.
        input_img: PIL image; it is resized to 256x256 before inference.
        device: torch device string, e.g. ``'cuda:0'``.

    Returns:
        ``(mesh_image, mesh_bones_image, final_mesh_path, prior_mesh_path)``, in the order of
        the Gradio outputs ``[view_1, view_2, shape_1, shape_2]``:
        - the gray-lit shaded reconstruction;
        - the same image with the posed skeleton overlaid;
        - the reconstructed mesh exported as .obj;
        - the bank prior mesh exported as .obj.
    """
    import tempfile  # only needed to hand mesh files to gr.Model3D

    epoch = 999
    total_iter = 999999
    model, memory_bank, memory_bank_keys = model_items[0], model_items[1], model_items[2]
    input_image = torch.stack([torchvision.transforms.ToTensor()(input_img)], dim=0).to(device)

    with torch.no_grad():
        model.netPrior.eval()
        model.netInstance.eval()
        input_image = torch.nn.functional.interpolate(input_image, size=(256, 256), mode='bilinear', align_corners=False)
        input_image = input_image[:, None, :, :]  # [B=1, F=1, 3, 256, 256]

        bank_embedding = get_bank_embedding(
            input_image,
            memory_bank_keys,
            memory_bank,
            model,
            memory_bank_topk=cfgs.get("memory_bank_topk", 10),
            memory_bank_dim=128,
        )
        prior_shape, dino_pred, classes_vectors = model.netPrior(
            category_name='tmp',
            perturb_sdf=False,
            total_iter=total_iter,
            is_training=False,
            class_embedding=bank_embedding,
        )
        Instance_out = model.netInstance(
            'tmp',
            input_image,
            prior_shape,
            epoch,
            dino_features=None,
            dino_clusters=None,
            total_iter=total_iter,
            is_training=False,
        )  # frame dim collapsed N=(B*F)
        # netInstance returns either 13 values, or 14 with a trailing im_features_map.
        if len(Instance_out) == 13:
            (shape, pose_raw, pose, mvp, w2c, campos, texture_pred, im_features,
             dino_feat_im_calc, deform, all_arti_params, light, forward_aux) = Instance_out
            im_features_map = None
        else:
            (shape, pose_raw, pose, mvp, w2c, campos, texture_pred, im_features,
             dino_feat_im_calc, deform, all_arti_params, light, forward_aux,
             im_features_map) = Instance_out

        # Render with a fixed frontal gray light rather than the predicted `light`.
        gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(device), amb=0.2, diff=0.7)
        image_pred, mask_pred, _, _, _, shading = model.render(
            shape, texture_pred, mvp, w2c, campos, 256, background=model.background_mode,
            im_features=im_features, light=gray_light, prior_shape=prior_shape, render_mode='diffuse',
            render_flow=False, dino_pred=None, im_features_map=im_features_map
        )
        mask_pred = mask_pred.expand_as(image_pred)
        shading = shading.expand_as(image_pred)

        # Render the posed skeleton (joint spheres + bone ellipsoids) with nvdiffrast,
        # using the same camera as the mesh.
        posed_bones = forward_aux["posed_bones"].squeeze(1)
        jc, bc = [66, 91, 140], [119, 144, 189]
        bones_meshes = create_bones_scene(posed_bones, joint_color=jc, bone_color=bc, show_end_point=True)
        bones_meshes = convert_textures_vertex_to_textures_uv(bones_meshes, color1=jc, color2=bc)
        nv_meshes = make_mesh(verts=bones_meshes.verts_padded(), faces=bones_meshes.faces_padded()[0:1],
                              uvs=bones_meshes.textures.verts_uvs_padded(), uv_idx=bones_meshes.textures.faces_uvs_padded()[0:1],
                              material=material_texture.Texture2D(bones_meshes.textures.maps_padded()))
        buffers = render_mesh(dr.RasterizeGLContext(), nv_meshes, mvp, w2c, campos, nv_meshes.material,
                              lgt=gray_light, feat=im_features, dino_pred=None, resolution=256, bsdf="diffuse")
        shaded = buffers["shaded"].permute(0, 3, 1, 2)
        bone_image = shaded[:, :3, :, :]
        bone_mask = shaded[:, 3:, :, :]
        mask_final = mask_pred.logical_or(bone_mask).int()
        # Blend the bones at 50% over the shaded mesh. Where bones cover background
        # (mask_final set, mask_pred not), add 0.5 so the overlay stays visible.
        image_with_bones = bone_image * bone_mask * 0.5 + (shading * (1 - bone_mask * 0.5) + 0.5 * (mask_final.float() - mask_pred.float()))

        mesh_image = save_images(shading, mask_pred)
        mesh_bones_image = save_images(image_with_bones, mask_final)

        # gr.Model3D displays files, so write both meshes into a fresh temporary directory.
        out_dir = tempfile.mkdtemp(prefix='fauna_')
        final_mesh_path = _export_first_mesh(shape, os.path.join(out_dir, 'reconstructed_mesh.obj'))
        prior_mesh_path = _export_first_mesh(prior_shape, os.path.join(out_dir, 'prior_mesh.obj'))

    return mesh_image, mesh_bones_image, final_mesh_path, prior_mesh_path
def _load_model_items(args, device):
    """Build the 3D Fauna model from ``args`` and load its checkpoint onto ``device``.

    Returns:
        ``(model_items, model_cfgs)``: ``model_items`` is ``[model, memory_bank,
        memory_bank_keys]`` with every element on ``device``, and ``model_cfgs`` is the
        config returned by ``setup_runtime``.
    """
    model_cfgs = setup_runtime(args)
    model_cfgs['body_bone_idx_preset'] = [3, 6, 6, 3]
    model = Unsup3DDDP(model_cfgs)
    # A hack: replace classes_vectors with a freshly initialised 123 x 128 parameter before
    # loading the checkpoint -- presumably so its shape matches the saved weights; TODO confirm.
    model.netPrior.classes_vectors = torch.nn.Parameter(
        torch.nn.init.uniform_(torch.empty(123, 128), a=-0.05, b=0.05))
    cp = torch.load(args.checkpoint_path, map_location=device)
    model.load_model_state(cp)
    model.to(device)
    # Tensor.to() is not in-place: keep the returned (on-device) tensors.
    memory_bank = cp['memory_bank'].to(device)
    memory_bank_keys = cp['memory_bank_keys'].to(device)
    return [model, memory_bank, memory_bank_keys], model_cfgs


def run_demo():
    """Parse CLI flags, load SAM and the 3D Fauna model, and launch the Gradio demo."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0', type=str,
                        help='Specify a GPU device')
    parser.add_argument('--num_workers', default=4, type=int,
                        help='Specify the number of worker threads for data loaders')
    parser.add_argument('--seed', default=0, type=int,
                        help='Specify a random seed')
    parser.add_argument('--config', default='./ckpts/configs.yml',
                        type=str)  # Model config path
    parser.add_argument('--checkpoint_path', default='./ckpts/iter0800000.pth', type=str)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8088'
    # Single-process group: with world_size=1 the only valid rank is 0.
    dist.init_process_group("gloo", rank=0, world_size=1)
    torch.cuda.set_device(_GPU_ID)
    args.rank = _GPU_ID
    args.world_size = 1
    # Prefer the environment's visible-device list; fall back to --gpu when it is unset
    # (indexing os.environ directly would raise KeyError).
    args.gpu = os.environ.get('CUDA_VISIBLE_DEVICES', args.gpu)
    device = f'cuda:{_GPU_ID}'

    model_items, model_cfgs = _load_model_items(args, device)
    predictor = sam_init()

    custom_theme = gr.themes.Soft(primary_hue="blue").set(
        button_secondary_background_fill="*neutral_100",
        button_secondary_background_fill_hover="*neutral_200")
    custom_css = '''#disp_image {
text-align: center; /* Horizontally center the content */
}'''

    with gr.Blocks(title=_TITLE, theme=custom_theme, css=custom_css) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
                gr.Markdown(_DESCRIPTION)
        with gr.Row(variant='panel'):
            with gr.Column(scale=1):
                input_image = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None)
                example_folder = os.path.join(os.path.dirname(__file__), "./example_images")
                # Sorted so the example gallery order does not depend on the filesystem.
                example_fns = [os.path.join(example_folder, example) for example in sorted(os.listdir(example_folder))]
                gr.Examples(
                    examples=example_fns,
                    inputs=[input_image],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=30
                )
            with gr.Column(scale=1):
                processed_image = gr.Image(type='pil', label="Processed Image", interactive=False, height=256, tool=None, image_mode='RGB', elem_id="disp_image")
                processed_image_highres = gr.Image(type='pil', image_mode='RGB', visible=False, tool=None)
                with gr.Accordion('Advanced options', open=True):
                    with gr.Row():
                        with gr.Column():
                            input_processing = gr.CheckboxGroup(['Use SAM to center animal'],
                                                                label='Input Image Preprocessing',
                                                                value=['Use SAM to center animal'],
                                                                info='untick this, if animal is already centered, e.g. in example images')
                run_btn = gr.Button('Generate', variant='primary', interactive=True)
        with gr.Row():
            view_1 = gr.Image(interactive=False, height=256, show_label=False)
            view_2 = gr.Image(interactive=False, height=256, show_label=False)
        with gr.Row():
            shape_1 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Reconstructed Model")
            shape_2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="Bank Base Shape Model")
        with gr.Row():
            view_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)
            normal_gallery = gr.Gallery(interactive=False, show_label=False, container=True, preview=True, allow_preview=False, height=1200)

        # Event `inputs` must be Gradio components, so the device string is bound via partial.
        run_btn.click(fn=partial(preprocess, predictor),
                      inputs=[input_image, input_processing],
                      outputs=[processed_image_highres, processed_image], queue=True
                      ).success(fn=partial(run_pipeline, model_items, model_cfgs, device=device),
                                inputs=[processed_image],
                                outputs=[view_1, view_2, shape_1, shape_2]
                                )

    demo.queue().launch(share=True, max_threads=80)
if __name__ == '__main__':
    # run_demo() parses its own flags with argparse and takes no parameters. Wrapping it
    # in fire.Fire would try to bind those flags to run_demo's (nonexistent) parameters
    # and reject them.
    run_demo()