import os
from functools import partial

import gradio as gr
import numpy as np
from PIL import Image

# Directory holding all precomputed results, and the sampler-configuration
# subdirectory that every dataset shares.
DEMO_ROOT = 'online_demo'
RUN_DIR = 'step-100_scale-6.0'
# Maximum (width, height) of images shown in the UI.
THUMBNAIL_SIZE = (256, 256)


def _run_path(dataset, *parts):
    """Return ``online_demo/<dataset>/step-100_scale-6.0/<parts...>``."""
    return os.path.join(DEMO_ROOT, dataset, RUN_DIR, *parts)


def _load_image(path):
    """Open the image at *path*, read its pixels eagerly and close the file.

    ``Image.open`` is lazy and keeps the file handle open until the pixel
    data is read; ``copy()`` inside the ``with`` forces that read, so the
    returned image no longer depends on the (now closed) file.
    """
    with Image.open(path) as image:
        return image.copy()


def _list_ids(dataset):
    """Return the sorted names of the per-image subdirectories of *dataset*.

    Only directories are listed, so stray files in the run directory are
    never offered as selectable ids.
    """
    return sorted(next(os.walk(_run_path(dataset)))[1])


def retrieve_input_image_wild(dataset, inputs):
    """Load the input image *inputs* of an in-the-wild *dataset* as a thumbnail.

    The image lives directly in the dataset's run directory as
    ``<inputs>.jpg``, falling back to ``<inputs>.png`` when no JPEG exists.

    Args:
        dataset: Dataset directory name under ``online_demo`` (e.g. ``'nerf_wild'``).
        inputs: Image id selected in the dropdown.

    Returns:
        A PIL image no larger than 256x256.

    Raises:
        FileNotFoundError: If neither the ``.jpg`` nor the ``.png`` file exists.
    """
    try:
        image = _load_image(_run_path(dataset, '%s.jpg' % inputs))
    except FileNotFoundError:
        image = _load_image(_run_path(dataset, '%s.png' % inputs))
    image.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
    return image


def retrieve_input_image(dataset, inputs):
    """Load ``<inputs>/input.png`` of *dataset* as a thumbnail.

    Thumbnailed so the "Load Input Image" button shows the image at the same
    size as the default image displayed when the demo starts.

    Args:
        dataset: Dataset directory name under ``online_demo`` (e.g. ``'GSO'``).
        inputs: Image id (subdirectory name) selected in the dropdown.

    Returns:
        A PIL image no larger than 256x256.
    """
    image = _load_image(_run_path(dataset, inputs, 'input.png'))
    image.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
    return image


def retrieve_novel_view(dataset, img_id, polar, azimuth, zoom, seed):
    """Load the precomputed novel view of *img_id* matching the slider values.

    The file is ``<img_id>/polar-P_azimuth-A_distance-D_seed-S.png`` where,
    given the UI sliders (polar in {-30, 0, 30}, azimuth in {0, 30, ..., 330},
    zoom in {-0.5, 0, 0.5}, seed in {0..3}):
    ``P = polar / 30 + 1``, ``A = azimuth / 30``, ``D = 2 * zoom + 1``.

    Args:
        dataset: Dataset directory name under ``online_demo``.
        img_id: Image id (subdirectory name).
        polar: Vertical rotation in degrees.
        azimuth: Horizontal rotation in degrees.
        zoom: Zoom offset.
        seed: Sampling seed.

    Returns:
        The novel-view PIL image.
    """
    # Sliders may deliver floats; round to the nearest step instead of
    # flooring/truncating so e.g. 29.9999 still maps to the 30-degree file.
    polar_idx = round(polar / 30) + 1
    azimuth_idx = round(azimuth / 30)
    distance_idx = round(zoom * 2) + 1
    seed_idx = round(seed)
    filename = 'polar-%d_azimuth-%d_distance-%d_seed-%d.png' % (
        polar_idx, azimuth_idx, distance_idx, seed_idx)
    return _load_image(_run_path(dataset, img_id, filename))


def _build_tab(title, dataset, default_id, load_input_image):
    """Add one dataset tab to the enclosing ``gr.Blocks`` context.

    Left column: input image, id dropdown and a "Load Input Image" button.
    Right column: novel-view output, view sliders and a "Generate Novel View"
    button.

    Args:
        title: Tab label.
        dataset: Dataset directory name under ``online_demo``.
        default_id: Image id selected and displayed when the demo starts.
        load_input_image: ``(dataset, img_id) -> PIL.Image`` loader, used both
            for the default image and by the load button so they always agree.
    """
    with gr.Tab(title):
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(load_input_image(dataset, default_id),
                                       shape=list(THUMBNAIL_SIZE))
                img_id = gr.Dropdown(_list_ids(dataset), value=default_id, label='options')
                text_button = gr.Button("Load Input Image")
                text_button.click(partial(load_input_image, dataset),
                                  inputs=img_id, outputs=input_image)
            with gr.Column(scale=1):
                novel_view = gr.Image(shape=list(THUMBNAIL_SIZE))
                inputs = [
                    img_id,
                    gr.Slider(-30, 30, value=0, step=30,
                              label='Polar angle (vertical rotation in degrees)'),
                    gr.Slider(0, 330, value=0, step=30,
                              label='Azimuth angle (horizontal rotation in degrees)'),
                    gr.Slider(-0.5, 0.5, value=0, step=0.5, label='Zoom'),
                    gr.Slider(0, 3, value=1, step=1, label='Random seed'),
                ]
                submit_button = gr.Button("Generate Novel View")
                submit_button.click(partial(retrieve_novel_view, dataset),
                                    inputs=inputs, outputs=novel_view)


with gr.Blocks() as demo:
    # gr.Markdown("Stable Diffusion Novel View Synthesis (Precomputed Results)")
    _build_tab("In-the-wild Images", 'nerf_wild', 'car1', retrieve_input_image_wild)
    _build_tab("Google Scanned Objects", 'GSO', 'SAMBA_HEMP', retrieve_input_image)
    _build_tab("RTMV", 'RTMV', '00000', retrieve_input_image)

if __name__ == "__main__":
    demo.launch()