|
import gradio as gr |
|
from models import build_model |
|
from PIL import Image |
|
import numpy as np |
|
import torchvision |
|
import ninja |
|
import torch |
|
from tqdm import trange |
|
import imageio |
|
|
|
# --- Generator setup ---------------------------------------------------------
# Absolute path of the trained checkpoint that this demo loads at start-up.
checkpoint = '/mnt/petrelfs/zhangqihang/data/berfscene_clevr.pth'

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only point `checkpoint` at files that come from a trusted source.
state = torch.load(checkpoint, map_location='cpu')

# Rebuild the generator from the constructor kwargs stored in the checkpoint
# under 'generator_smooth', then load the matching weights.
G = build_model(**state['model_kwargs_init']['generator_smooth'])

# strict=False tolerates key mismatches instead of raising. The two returned
# values are presumably the missing / unexpected key lists of
# torch.nn.Module.load_state_dict (assumes G is an nn.Module); report them so
# that a checkpoint/model mismatch does not go unnoticed.
o0, o1 = G.load_state_dict(state['models']['generator_smooth'], strict=False)
if o0 or o1:
    print('checkpoint key mismatch:', 'missing', list(o0),
          'unexpected', list(o1), flush=True)

# Inference mode, on the GPU.
G.eval().cuda()

# Reset the offsets of the synthesis input layer.
G.backbone.synthesis.input.x_offset = 0
G.backbone.synthesis.input.y_offset = 0

# Not referenced anywhere else in this script; kept as a module-level name.
G_kwargs = dict(noise_mode='const',
                fused_modulate=False,
                impl='cuda',
                fp16_res=None)
|
|
|
def trans(x, y, z, length):
    """Convert scene coordinates to pixel coordinates on a square canvas.

    A span of 9 input units corresponds to 256 pixels. The 256-pixel window
    is centred inside a canvas whose side is ``length``, and the first axis
    is mirrored.

    Args:
        x: First scene coordinate (mirrored in the output).
        y: Second scene coordinate.
        z: Size value; it is only rescaled, never shifted.
        length: Side length of the canvas in pixels.

    Returns:
        Tuple ``(x, y, z)`` expressed in pixels.
    """
    # Shift that centres the 256-pixel window inside the canvas.
    origin = 0.5 * length - 128

    pixel_x = origin + 256 - (x / 9 + .5) * 256
    pixel_y = origin + (y / 9 + .5) * 256
    pixel_z = z / 9 * 256
    return pixel_x, pixel_y, pixel_z
|
def get_bev_from_objs(objs, length=256, scale=6):
    """Rasterise a list of objects into a multi-channel bird's-eye-view map.

    Args:
        objs: Iterable of ``(x, y, z, shape, color, material, rot)`` entries.
            ``x``, ``y`` and ``z`` are converted with ``trans``; the converted
            ``z`` is the half-extent of a cube footprint or the radius of the
            circular footprint used for every other shape. ``rot`` is an
            angle in degrees and only affects cubes.
        length: Canvas height in pixels.
        scale: The canvas width is ``length * scale``.

    Returns:
        ``float32`` NumPy array of shape ``(14, length, length * scale)``.
        Channel 0 flags occupied pixels; the following channels are one-hot
        encodings of colour, shape and material, in that order. Where
        footprints overlap, the object listed later wins.
    """
    height, width = length, length * scale
    num_channels = 14
    bev = np.zeros([height, width, num_channels])

    # 1-based row / column index of every pixel.
    rows = np.ones([height, width]).cumsum(0)
    cols = np.ones([height, width]).cumsum(1)

    # First channel of each one-hot group (channel 0 is the occupancy flag).
    color_base = 1
    shape_base = 1 + len(COLOR_NAME_LIST)
    material_base = 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)

    for obj_x, obj_y, obj_z, shape, color, material, rot in objs:
        # The first value returned by trans is used as the column centre and
        # the second one as the row centre.
        col_centre, row_centre, extent = trans(obj_x, obj_y, obj_z, length)

        feat = np.zeros(num_channels)
        feat[0] = 1
        feat[COLOR_NAME_LIST.index(color) + color_base] = 1
        feat[SHAPE_NAME_LIST.index(shape) + shape_base] = 1
        feat[MATERIAL_NAME_LIST.index(material) + material_base] = 1

        angle = rot / 180 * np.pi
        rot_sin = np.sin(angle)
        rot_cos = np.cos(angle)

        d_row = rows - row_centre
        d_col = cols - col_centre

        if shape == 'cube':
            # Rotated square: both rotated coordinates must lie within `extent`.
            along = np.abs(rot_cos * d_row + rot_sin * d_col) <= extent
            across = np.abs(-rot_sin * d_row + rot_cos * d_col) <= extent
            mask = along & across
        else:
            # Every other shape gets a circular footprint.
            mask = (d_row ** 2 + d_col ** 2) ** 0.5 <= extent

        bev[mask] = feat

    # HWC -> CHW, as float32.
    bev = bev.transpose(2, 0, 1).astype(np.float32)

    rotate_angle = 0
    rotated = torchvision.transforms.functional.rotate(torch.tensor(bev), rotate_angle)
    return rotated.numpy()
|
|
|
|
|
# Label vocabularies. The position of a name inside its list selects the
# one-hot channel that get_bev_from_objs sets for that attribute.
# NOTE(review): 'purple' appears twice in the colour list, so the channel of
# the second occurrence can never be selected via list.index -- confirm
# whether another colour name was intended there.
COLOR_NAME_LIST = [
    'cyan', 'green', 'purple', 'red',
    'yellow', 'gray', 'purple', 'blue',
]
SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
MATERIAL_NAME_LIST = ['rubber', 'metal']
|
|
|
# Point sets keyed by character. Each entry is an [x, y] pair; the script
# places one object per pair when it spells out a string.
xy_lib = {
    # NOTE(review): [-1, 2] is listed twice for 'B'; the duplicate is kept
    # as-is -- confirm whether a different point was intended.
    'B': [
        [-2.5, 1.25],
        [-2, 2],
        [-2, 0.5],
        [-2, -1],
        [-1, -1.25],
        [-1, 2],
        [-1, 0],
        [-1, 2],
        [0, 1],
        [0, 0],
        [0, -1.25],
        [0, 2],
    ],
    'R': [
        [0, -1],
        [0, 0],
        [0, 1],
        [0, 2],
        [-1, -1],
        [-2, -1],
        [-2, 0],
        [-2.25, 2],
        [-1, 1],
    ],
    'C': [
        [0, -1],
        [0, 0],
        [0, 1],
        [0, 2],
        [-1, -1],
        [-1, 2],
        [-2, -1],
        [-2, 2],
    ],
    's': [
        [0, -1],
        [0, 0],
        [0, 2],
        [-1, -1],
        [-1, 2],
        [-2, -1],
        [-2, 1],
        [-2, 2],
        [-1, .5],
    ],
    'F': [
        [0, -1],
        [0, 0],
        [0, 1],
        [0, 2],
        [-1, -1],
        [-2, -1],
        [-2, .5],
        [-1, .5],
    ],
    'c': [
        [0.8, 1],
        [0, 0.1],
        [0, 1.9],
    ],
    'e': [
        [0, -1],
        [0, 0],
        [0, 1],
        [0, 2],
        [-1, -1],
        [-1, 2],
        [-2, -1],
        [-2, .5],
        [-2, 2],
        [-1, .5],
    ],
    'n': [
        [0, 1],
        [0, -1],
        [0, 0.1],
        [0, 1.9],
        [-1, 0],
        [-2, 1],
        [-3, -1],
        [-3, 1],
        [-3, 0.1],
        [-3, 1.9],
    ],
}
|
# Amount added to the running `offset` after each character has been placed.
offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8)

# Text to lay out, one character at a time.
s = 'BeRFsCene'

# Object list in the layout consumed by get_bev_from_objs:
# [x, y, z, shape, color, material, rot].
objs = []
offset = 2
for c in s:
    # One randomly drawn colour per character, shared by all of its objects.
    color = np.random.choice(COLOR_NAME_LIST)
    for x, y in xy_lib[c]:
        # Stretch and shift the second coordinate, then move the first one
        # by the running offset of the current character.
        y *= 1.5
        y -= 0.5
        x -= offset
        z = 0.35
        shape = 'cube'
        material = 'rubber'
        rot = 0
        objs.append([x, y, z, shape, color, material, rot])
    offset += offset_x[c]
|
|
|
# --- Fixed inputs shared by every inference call -----------------------------
batch_size = 1

# Latent code, sampled once at start-up and reused for every request.
code = torch.randn(1, G.z_dim).cuda()

# Converter used to turn the generated tensor into a PIL image.
to_pil = torchvision.transforms.ToPILImage()

# Full BEV map of the laid-out objects with a leading batch axis, on the GPU.
large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda().unsqueeze(0)

# Only the first 256 entries along the last axis are fed to the generator.
bevs = large_bevs[..., :256]

# Camera argument passed to the generator as one row of 25 values.
# NOTE(review): presumably a flattened 4x4 pose followed by a flattened 3x3
# intrinsics matrix -- confirm against the generator's expected format.
RT = torch.tensor(
    [[
        -1.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.5000, -0.8660, 10.3923,
        0.0000, -0.8660, -0.5000, 6.0000,
        0.0000, 0.0000, 0.0000, 1.0000,
        262.5000, 0.0000, 32.0000,
        0.0000, 262.5000, 32.0000,
        0.0000, 0.0000, 1.0000,
    ]],
    device='cuda',
)

print('prepare finish', flush=True)
|
|
|
def inference(name):
    """Run the generator once on the prepared inputs and return the image.

    Args:
        name: Value forwarded by the UI event; it is only logged.

    Returns:
        The generated image, converted to a PIL image and then to a NumPy
        array.
    """
    print('inference', name, flush=True)

    # Pure inference: do not record an autograd graph for the forward pass.
    with torch.no_grad():
        gen = G(code, RT, bevs)

    # Rescale the first image of the batch (assumes the generator output is
    # in [-1, 1] -- TODO confirm) and clamp to [0, 1] so that out-of-range
    # values cannot wrap around during the 8-bit conversion.
    rgb = (gen['gen_output']['image'][0] * .5 + .5).clamp(0, 1)

    print('inference', name, flush=True)
    return np.array(to_pil(rgb))
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI ---------------------------------------------------------------
# NOTE(review): the container nesting below was reconstructed from the order
# of the statements; confirm that it matches the intended page layout.
with gr.Blocks() as demo:
    gr.HTML(
        """
        abc
        """)

    with gr.Group():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            num_frames = gr.Dropdown(
                                [
                                    "24 - frames",
                                    "32 - frames",
                                    "40 - frames",
                                    "48 - frames",
                                    "56 - frames",
                                    "80 - recommended to run on local GPUs",
                                    "240 - recommended to run on local GPUs",
                                    "600 - recommended to run on local GPUs",
                                    "1200 - recommended to run on local GPUs",
                                    "10000 - recommended to run on local GPUs",
                                ],
                                label="Number of Video Frames",
                                info="For >56 frames use local workstation!",
                                value="24 - frames",
                            )

                with gr.Row():
                    with gr.Row():
                        btn = gr.Button("Result")

            gallery = gr.Image(label='img', show_label=True, elem_id="gallery")

    # `inference` returns a NumPy array, so the Image component's default
    # post-processing must run to turn it into something the browser can
    # display; it is therefore not disabled here.
    btn.click(fn=inference, inputs=num_frames, outputs=[gallery])

demo.queue()
demo.launch(server_name='0.0.0.0', server_port=10093, debug=True, show_error=True)
|
|