# BerfScene / app.py
import gradio as gr
from models import build_model
from PIL import Image
import numpy as np
import torchvision
import ninja  # likely imported only to ensure the ninja build backend is present for the custom CUDA ops
import torch
from tqdm import trange  # not used in this file
import imageio  # not used in this file
# Load the pretrained BerfScene generator (CLEVR checkpoint) and move it to the GPU.
checkpoint = '/mnt/petrelfs/zhangqihang/data/berfscene_clevr.pth'
state = torch.load(checkpoint, map_location='cpu')
G = build_model(**state['model_kwargs_init']['generator_smooth'])
missing_keys, unexpected_keys = G.load_state_dict(state['models']['generator_smooth'], strict=False)
G.eval().cuda()
# Zero out the synthesis input grid offsets.
G.backbone.synthesis.input.x_offset = 0
G.backbone.synthesis.input.y_offset = 0
G_kwargs = dict(noise_mode='const',
                fused_modulate=False,
                impl='cuda',
                fp16_res=None)
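# Note: these StyleGAN-style synthesis options are collected here but never passed
# to the generator call below; if G's forward accepts them as keyword arguments
# (an assumption about its signature, not verified here), the call could read:
#   gen = G(code, RT, bevs, **G_kwargs)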
def trans(x, y, z, length):
    """Map CLEVR-style scene coordinates (roughly [-4.5, 4.5]) to BEV pixel coordinates."""
    w = h = length
    x = 0.5 * w - 128 + 256 - (x / 9 + .5) * 256
    y = 0.5 * h - 128 + (y / 9 + .5) * 256
    z = z / 9 * 256
    return x, y, z
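# Worked example: with the default length of 256, an object at the scene origin
# with height 0.35 lands at the center of the 256x256 BEV tile:
#   trans(0, 0, 0.35, 256) -> (128.0, 128.0, ~9.96)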
def get_bev_from_objs(objs, length=256, scale=6):
    """Rasterize a list of objects into a (14, length, length * scale) BEV feature map."""
    h, w = length, length * scale
    nc = 14  # 1 occupancy + 8 colors + 3 shapes + 2 materials
    canvas = np.zeros([h, w, nc])
    xx = np.ones([h, w]).cumsum(0)  # row index of every pixel
    yy = np.ones([h, w]).cumsum(1)  # column index of every pixel
    for x, y, z, shape, color, material, rot in objs:
        # trans() returns pixel coordinates; they are unpacked as (y, x, z) so that
        # x is compared against rows (xx) and y against columns (yy) below.
        y, x, z = trans(x, y, z, length)
        # One-hot feature vector: occupancy, then color, shape and material.
        feat = [0] * nc
        feat[0] = 1
        feat[COLOR_NAME_LIST.index(color) + 1] = 1
        feat[SHAPE_NAME_LIST.index(shape) + 1 + len(COLOR_NAME_LIST)] = 1
        feat[MATERIAL_NAME_LIST.index(material) + 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)] = 1
        feat = np.array(feat)
        rot_sin = np.sin(rot / 180 * np.pi)
        rot_cos = np.cos(rot / 180 * np.pi)
        if shape == 'cube':
            # Square footprint of half-size z, rotated by `rot` degrees.
            mask = (np.abs(+rot_cos * (xx - x) + rot_sin * (yy - y)) <= z) * \
                   (np.abs(-rot_sin * (xx - x) + rot_cos * (yy - y)) <= z)
        else:
            # Spheres and cylinders are stamped as discs of radius z.
            mask = ((xx - x) ** 2 + (y - yy) ** 2) ** 0.5 <= z
        canvas[mask] = feat
    canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)  # HWC -> CHW
    rotate_angle = 0  # no extra rotation applied to the finished map
    canvas = torchvision.transforms.functional.rotate(torch.tensor(canvas), rotate_angle).numpy()
    return canvas
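# Channel layout of the returned array:
#   channel 0      -> occupancy
#   channels 1-8   -> one-hot color (COLOR_NAME_LIST)
#   channels 9-11  -> one-hot shape (SHAPE_NAME_LIST)
#   channels 12-13 -> one-hot material (MATERIAL_NAME_LIST)
# Example sketch (kept commented out so the demo's startup is unchanged):
#   bev = get_bev_from_objs([[0, 0, 0.35, 'cube', 'red', 'rubber', 0]])
#   Image.fromarray((255 * (1 - bev[0])).astype(np.uint8)).save('bev_preview.png')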
# COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'brown', 'blue']
COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
MATERIAL_NAME_LIST = ['rubber', 'metal']
# Per-letter layouts: (x, y) grid positions of the cubes that trace each character.
xy_lib = dict()
xy_lib['B'] = [
    [-2.5, 1.25],
    [-2, 2],
    [-2, 0.5],
    [-2, -1],
    [-1, -1.25],
    [-1, 2],
    [-1, 0],
    [-1, 2],
    [0, 1],
    [0, 0],
    [0, -1.25],
    [0, 2],
    # [-1, 2],
]
xy_lib['R'] = [
[0, -1],
[0, 0],
[0, 1],
[0, 2],
[-1, -1],
# [-1, 2],
[-2, -1],
[-2, 0],
[-2.25, 2],
[-1, 1]
]
xy_lib['C'] = [
[0, -1],
[0, 0],
[0, 1],
[0, 2],
[-1, -1],
[-1, 2],
[-2, -1],
# [-2, .5],
[-2, 2],
# [-1, .5]
]
xy_lib['s'] = [
[0, -1],
[0, 0],
[0, 2],
[-1, -1],
[-1, 2],
[-2, -1],
[-2, 1],
[-2, 2],
[-1, .5]
]
xy_lib['F'] = [
[0, -1],
[0, 0],
[0, 1],
[0, 2],
[-1, -1],
# [-1, 2],
[-2, -1],
[-2, .5],
# [-2, 2],
[-1, .5]
]
xy_lib['c'] = [
[0.8,1],
# [-0.8,1],
[0,0.1],
[0,1.9],
]
xy_lib['e'] = [
[0, -1],
[0, 0],
[0, 1],
[0, 2],
[-1, -1],
[-1, 2],
[-2, -1],
[-2, .5],
[-2, 2],
[-1, .5]
]
xy_lib['n'] = [
[0,1],
[0,-1],
[0,0.1],
[0,1.9],
[-1,0],
[-2,1],
[-3,-1],
[-3,1],
[-3,0.1],
[-3,1.9],
]
offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8)  # horizontal advance (in layout units) after each character
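# Sketch: every character to be rendered needs both a layout and an advance.
# Kept commented out so the demo's startup behavior is unchanged:
#   assert all(c in xy_lib and c in offset_x for c in 'BeRFsCene')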
s = 'BeRFsCene'
objs = []
offset = 2
for idx, c in enumerate(s):
    xy = xy_lib[c]
    color = np.random.choice(COLOR_NAME_LIST)  # one random color per character
    for i in range(len(xy)):
        x, y = xy[i]
        y *= 1.5      # stretch the letter vertically within the strip
        y -= 0.5
        x -= offset   # place the character in its slot along the strip
        z = 0.35
        shape = 'cube'
        material = 'rubber'
        rot = 0
        objs.append([x, y, z, shape, color, material, rot])
    offset += offset_x[c]
# Quick sanity preview of the occupancy channel (the result is not kept).
Image.fromarray((255 * .8 - get_bev_from_objs(objs)[0] * .8 * 255).astype(np.uint8))
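# At this point `objs` holds one rubber cube per entry of each character's layout
# (len(objs) == sum(len(xy_lib[c]) for c in s)), laid out along a
# 256 x (256 * 6) bird's-eye-view strip.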
batch_size = 1
code = torch.randn(1, G.z_dim).cuda()  # random latent code for the generator
to_pil = torchvision.transforms.ToPILImage()
large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda()[None]  # (1, 14, 256, 1536)
bevs = large_bevs[..., 0: 0 + 256]  # crop the leftmost 256-pixel square window
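# Sketch (assumption about intended use): sliding this 256-pixel window along the
# wide BEV strip would pan the generated view across the text, e.g. for frame i:
#   start = min(i * step, large_bevs.shape[-1] - 256)
#   bevs_i = large_bevs[..., start:start + 256]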
RT = torch.tensor([[ -1.0000,   0.0000,   0.0000,   0.0000,
                      0.0000,   0.5000,  -0.8660,  10.3923,
                      0.0000,  -0.8660,  -0.5000,   6.0000,
                      0.0000,   0.0000,   0.0000,   1.0000,
                    262.5000,   0.0000,  32.0000,
                      0.0000, 262.5000,  32.0000,
                      0.0000,   0.0000,   1.0000]], device='cuda')
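# Note (assumption about the layout): the 25 values look like a flattened 4x4
# camera extrinsic matrix followed by a flattened 3x3 intrinsic matrix, e.g.:
#   extrinsic = RT[0, :16].reshape(4, 4)
#   intrinsic = RT[0, 16:].reshape(3, 3)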
print('preparation finished', flush=True)
def inference(name):
    print('inference', name, flush=True)
    # NOTE: the selected frame count (`name`) is currently not used; a single
    # frame is rendered from the fixed camera and BEV window prepared above.
    gen = G(code, RT, bevs)
    rgb = gen['gen_output']['image'][0] * .5 + .5  # map from [-1, 1] to [0, 1]
    print('inference done', name, flush=True)
    return np.array(to_pil(rgb))
    # to_pil(rgb).save('tmp.png')
    # save_path = '/mnt/petrelfs/zhangqihang/code/3d-scene-gen/tmp.png'
    # return [save_path]
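# Sketch (assumption about intent): to honor the requested frame count, the BEV
# window could be slid across the strip and the frames stitched into a video
# with the already-imported imageio, e.g.:
#   n_frames = int(name.split(' ')[0])
#   starts = np.linspace(0, large_bevs.shape[-1] - 256, n_frames).astype(int)
#   frames = [np.array(to_pil(G(code, RT, large_bevs[..., s0:s0 + 256])['gen_output']['image'][0] * .5 + .5))
#             for s0 in starts]
#   imageio.mimsave('berfscene.mp4', frames, fps=24)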
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>BerfScene demo</h1>
        """)
with gr.Group():
with gr.Row():
with gr.Column():
with gr.Row():
with gr.Column():
with gr.Row():
                            num_frames = gr.Dropdown(
                                ["24 - frames", "32 - frames", "40 - frames", "48 - frames", "56 - frames",
                                 "80 - recommended to run on local GPUs",
                                 "240 - recommended to run on local GPUs",
                                 "600 - recommended to run on local GPUs",
                                 "1200 - recommended to run on local GPUs",
                                 "10000 - recommended to run on local GPUs"],
                                label="Number of Video Frames",
                                info="For >56 frames use local workstation!",
                                value="24 - frames")
with gr.Row():
with gr.Row():
btn = gr.Button("Result")
gallery = gr.Image(label='img', show_label=True, elem_id="gallery")
btn.click(fn=inference, inputs=num_frames, outputs=[gallery], postprocess=False)
demo.queue()
demo.launch(server_name='0.0.0.0', server_port=10093, debug=True, show_error=True)