'''
conda activate zero123
cd zero123
python gradio_new.py 0
'''

import diffusers  # 0.12.1
import math
import fire
import gradio as gr
import lovely_numpy
import lovely_tensors
import numpy as np
import os
import plotly.express as px
import plotly.graph_objects as go
import rich
import sys
import time
import torch
from contextlib import nullcontext
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from einops import rearrange
from functools import partial
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import create_carvekit_interface, load_and_preprocess, instantiate_from_config
from lovely_numpy import lo
from omegaconf import OmegaConf
from PIL import Image
from rich import print
from transformers import AutoFeatureExtractor
from torch import autocast
from torchvision import transforms


_SHOW_DESC = True
_SHOW_INTERMEDIATE = False
# _SHOW_INTERMEDIATE = True
_GPU_INDEX = 0
# _GPU_INDEX = 2

# _TITLE = 'Zero-Shot Control of Camera Viewpoints within a Single Image'
_TITLE = 'Zero-1-to-3: Zero-shot One Image to 3D Object'

# This demo allows you to generate novel viewpoints of an object depicted in an input image
# using a fine-tuned version of Stable Diffusion.
_DESCRIPTION = '''
This live demo allows you to control camera rotation and thereby generate novel viewpoints of an object within a single image.
It is based on Stable Diffusion. Check out our [project webpage](https://zero123.cs.columbia.edu/) and [paper](https://arxiv.org/pdf/2303.11328.pdf) if you want to learn more about the method!
Note that this model is not intended for images of humans or faces, and is unlikely to work well for them.
'''

_ARTICLE = 'See uses.md'


def load_model_from_config(config, ckpt, device, verbose=False):
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location='cpu')
    if 'global_step' in pl_sd:
        print(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print('missing keys:')
        print(m)
    if len(u) > 0 and verbose:
        print('unexpected keys:')
        print(u)

    model.to(device)
    model.eval()
    return model


@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale,
                 ddim_eta, x, y, z):
    precision_scope = autocast if precision == 'autocast' else nullcontext
    with precision_scope('cuda'):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
            # Relative camera pose embedding: polar angle (radians), sin/cos of azimuth, radius offset.
            T = torch.tensor([math.radians(x), math.sin(
                math.radians(y)), math.cos(math.radians(y)), z])
            T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device)
            c = torch.cat([c, T], dim=-1)
            c = model.cc_projection(c)
            cond = {}
            cond['c_crossattn'] = [c]
            c_concat = model.encode_first_stage(input_im.to(c.device)).mode().detach()
            cond['c_concat'] = [c_concat.repeat(n_samples, 1, 1, 1)]
            if scale != 1.0:
                uc = {}
                uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)]
                uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(S=ddim_steps,
                                             conditioning=cond,
                                             batch_size=n_samples,
                                             shape=shape,
                                             verbose=False,
                                             unconditional_guidance_scale=scale,
                                             unconditional_conditioning=uc,
                                             eta=ddim_eta,
                                             x_T=None)
            print(samples_ddim.shape)
            # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
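

# A minimal, hypothetical example of calling sample_model directly, assuming `model` was loaded
# with load_model_from_config and `input_im` is a (1, 3, 256, 256) tensor scaled to [-1, 1];
# kept as a comment so nothing runs at import time:
#   sampler = DDIMSampler(model)
#   x_samples = sample_model(input_im, model, sampler, 'fp32', 256, 256,
#                            ddim_steps=50, n_samples=4, scale=3.0, ddim_eta=1.0,
#                            x=-30.0, y=90.0, z=0.0)  # x: polar (deg), y: azimuth (deg), z: radius offset.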


class CameraVisualizer:
    def __init__(self, gradio_plot):
        self._gradio_plot = gradio_plot
        self._fig = None
        self._polar = 0.0
        self._azimuth = 0.0
        self._radius = 0.0
        self._raw_image = None
        self._8bit_image = None
        self._image_colorscale = None

    def polar_change(self, value):
        self._polar = value
        # return self.update_figure()

    def azimuth_change(self, value):
        self._azimuth = value
        # return self.update_figure()

    def radius_change(self, value):
        self._radius = value
        # return self.update_figure()

    def encode_image(self, raw_image):
        '''
        :param raw_image (H, W, 3) array of uint8 in [0, 255].
        '''
        # https://stackoverflow.com/questions/60685749/python-plotly-how-to-add-an-image-to-a-3d-scatter-plot
        dum_img = Image.fromarray(np.ones((3, 3, 3), dtype='uint8')).convert('P', palette='WEB')
        idx_to_color = np.array(dum_img.getpalette()).reshape((-1, 3))

        self._raw_image = raw_image
        self._8bit_image = Image.fromarray(raw_image).convert('P', palette='WEB', dither=None)
        # self._8bit_image = Image.fromarray(raw_image.clip(0, 254)).convert(
        #     'P', palette='WEB', dither=None)
        self._image_colorscale = [
            [i / 255.0, 'rgb({}, {}, {})'.format(*rgb)] for i, rgb in enumerate(idx_to_color)]

        # return self.update_figure()

    def update_figure(self):
        fig = go.Figure()

        if self._raw_image is not None:
            (H, W, C) = self._raw_image.shape

            x = np.zeros((H, W))
            (y, z) = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(1.0, -1.0, H) * H / W)
            print('x:', lo(x))
            print('y:', lo(y))
            print('z:', lo(z))

            fig.add_trace(go.Surface(
                x=x, y=y, z=z,
                surfacecolor=self._8bit_image,
                cmin=0,
                cmax=255,
                colorscale=self._image_colorscale,
                showscale=False,
                lighting_diffuse=1.0,
                lighting_ambient=1.0,
                lighting_fresnel=1.0,
                lighting_roughness=1.0,
                lighting_specular=0.3))

            scene_bounds = 3.5
            base_radius = 2.5
            zoom_scale = 1.5  # Note that input radius offset is in [-0.5, 0.5].
            fov_deg = 50.0
            edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4), (4, 1)]

            input_cone = calc_cam_cone_pts_3d(
                0.0, 0.0, base_radius, fov_deg)  # (5, 3).
            output_cone = calc_cam_cone_pts_3d(
                self._polar, self._azimuth, base_radius + self._radius * zoom_scale, fov_deg)  # (5, 3).
            # print('input_cone:', lo(input_cone).v)
            # print('output_cone:', lo(output_cone).v)

            for (cone, clr, legend) in [(input_cone, 'green', 'Input view'),
                                        (output_cone, 'blue', 'Target view')]:

                for (i, edge) in enumerate(edges):
                    (x1, x2) = (cone[edge[0], 0], cone[edge[1], 0])
                    (y1, y2) = (cone[edge[0], 1], cone[edge[1], 1])
                    (z1, z2) = (cone[edge[0], 2], cone[edge[1], 2])
                    fig.add_trace(go.Scatter3d(
                        x=[x1, x2], y=[y1, y2], z=[z1, z2], mode='lines',
                        line=dict(color=clr, width=3),
                        name=legend, showlegend=(i == 0)))
                    # text=(legend if i == 0 else None),
                    # textposition='bottom center'))
                    # hoverinfo='text',
                    # hovertext='hovertext'))

                # Add label.
                if cone[0, 2] <= base_radius / 2.0:
                    fig.add_trace(go.Scatter3d(
                        x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] - 0.05], showlegend=False,
                        mode='text', text=legend, textposition='bottom center'))
                else:
                    fig.add_trace(go.Scatter3d(
                        x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] + 0.05], showlegend=False,
                        mode='text', text=legend, textposition='top center'))

            # look at center of scene
            fig.update_layout(
                # width=640,
                # height=480,
                # height=400,
                height=360,
                autosize=True,
                hovermode=False,
                margin=go.layout.Margin(l=0, r=0, b=0, t=0),
                showlegend=True,
                legend=dict(
                    yanchor='bottom',
                    y=0.01,
                    xanchor='right',
                    x=0.99,
                ),
                scene=dict(
                    aspectmode='manual',
                    aspectratio=dict(x=1, y=1, z=1.0),
                    camera=dict(
                        eye=dict(x=base_radius - 1.6, y=0.0, z=0.6),
                        center=dict(x=0.0, y=0.0, z=0.0),
                        up=dict(x=0.0, y=0.0, z=1.0)),
                    xaxis_title='',
                    yaxis_title='',
                    zaxis_title='',
                    xaxis=dict(
                        range=[-scene_bounds, scene_bounds],
                        showticklabels=False,
                        showgrid=True,
                        zeroline=False,
                        showbackground=True,
                        showspikes=False,
                        showline=False,
                        ticks=''),
                    yaxis=dict(
                        range=[-scene_bounds, scene_bounds],
                        showticklabels=False,
                        showgrid=True,
                        zeroline=False,
                        showbackground=True,
                        showspikes=False,
                        showline=False,
                        ticks=''),
                    zaxis=dict(
                        range=[-scene_bounds, scene_bounds],
                        showticklabels=False,
                        showgrid=True,
                        zeroline=False,
                        showbackground=True,
                        showspikes=False,
                        showline=False,
                        ticks='')))

        self._fig = fig
        return fig
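

# Hypothetical standalone use of CameraVisualizer outside of the Gradio app, assuming a file
# 'input.png' exists on disk; kept as a comment so nothing runs at import time:
#   vis = CameraVisualizer(None)
#   vis.encode_image(np.asarray(Image.open('input.png').convert('RGB')))
#   vis.polar_change(-30.0)
#   vis.azimuth_change(45.0)
#   vis.update_figure().show()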


def preprocess_image(models, input_im, preprocess):
    '''
    :param input_im (PIL Image).
    :return input_im (H, W, 3) array in [0, 1].
    '''
    print('old input_im:', input_im.size)
    start_time = time.time()

    if preprocess:
        input_im = load_and_preprocess(models['carvekit'], input_im)
        input_im = (input_im / 255.0).astype(np.float32)
        # (H, W, 3) array in [0, 1].
    else:
        input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
        input_im = np.asarray(input_im, dtype=np.float32) / 255.0
        # (H, W, 4) array in [0, 1].

        # old method: thresholding background, very important
        # input_im[input_im[:, :, -1] <= 0.9] = [1., 1., 1., 1.]

        # new method: apply correct method of compositing to avoid sudden transitions / thresholding
        # (smoothly transition foreground to white background based on alpha values)
        alpha = input_im[:, :, 3:4]
        white_im = np.ones_like(input_im)
        input_im = alpha * input_im + (1.0 - alpha) * white_im

        input_im = input_im[:, :, 0:3]
        # (H, W, 3) array in [0, 1].

    print(f'Infer foreground mask (preprocess_image) took {time.time() - start_time:.3f}s.')
    print('new input_im:', lo(input_im))

    return input_im
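

# main_run below dispatches on substrings of `return_what`:
#   'vis'    -> only refresh the camera-pose plot and the preprocessed image,
#   'gen'    -> additionally run the diffusion sampler and return the generated views,
#   'angles' -> also return the (polar, azimuth, radius) values (used by the preset buttons),
#   'rand'   -> sample a random target pose before generating.
# Hypothetical direct call outside Gradio, assuming `models`, `device`, `cam_vis` and
# 'input.png' exist; shown as a comment only:
#   outputs = main_run(models, device, cam_vis, 'gen', x=0.0, y=90.0, z=0.0,
#                      raw_im=Image.open('input.png'), preprocess=True)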


def main_run(models, device, cam_vis, return_what,
             x=0.0, y=0.0, z=0.0,
             raw_im=None, preprocess=True,
             scale=3.0, n_samples=4, ddim_steps=50, ddim_eta=1.0,
             precision='fp32', h=256, w=256):
    '''
    :param raw_im (PIL Image).
    '''
    raw_im.thumbnail([1536, 1536], Image.Resampling.LANCZOS)
    safety_checker_input = models['clip_fe'](raw_im, return_tensors='pt').to(device)
    (image, has_nsfw_concept) = models['nsfw'](
        images=np.ones((1, 3)), clip_input=safety_checker_input.pixel_values)
    print('has_nsfw_concept:', has_nsfw_concept)
    if np.any(has_nsfw_concept):
        print('NSFW content detected.')
        to_return = [None] * 10
        description = ('### Unfortunately, '
                       'potential NSFW content was detected, '
                       'which is not supported by our model. '
                       'Please try again with a different image. ')
        if 'angles' in return_what:
            to_return[0] = 0.0
            to_return[1] = 0.0
            to_return[2] = 0.0
            to_return[3] = description
        else:
            to_return[0] = description
        return to_return
    else:
        print('Safety check passed.')

    input_im = preprocess_image(models, raw_im, preprocess)

    # if np.random.rand() < 0.3:
    #     description = ('Unfortunately, a human, a face, or potential NSFW content was detected, '
    #                    'which is not supported by our model.')
    #     if vis_only:
    #         return (None, None, description)
    #     else:
    #         return (None, None, None, description)

    show_in_im1 = (input_im * 255.0).astype(np.uint8)
    show_in_im2 = Image.fromarray(show_in_im1)

    if 'rand' in return_what:
        x = int(np.round(np.arcsin(np.random.uniform(-1.0, 1.0)) * 160.0 / np.pi))  # [-80, 80].
        y = int(np.round(np.random.uniform(-150.0, 150.0)))
        z = 0.0

    cam_vis.polar_change(x)
    cam_vis.azimuth_change(y)
    cam_vis.radius_change(z)
    cam_vis.encode_image(show_in_im1)
    new_fig = cam_vis.update_figure()

    if 'vis' in return_what:
        description = ('The viewpoints are visualized on the top right. '
                       'Click Run Generation to update the results on the bottom right.')

        if 'angles' in return_what:
            return (x, y, z, description, new_fig, show_in_im2)
        else:
            return (description, new_fig, show_in_im2)

    elif 'gen' in return_what:
        input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
        input_im = input_im * 2 - 1
        input_im = transforms.functional.resize(input_im, [h, w])

        sampler = DDIMSampler(models['turncam'])
        # used_x = -x  # NOTE: Polar makes more sense in Basile's opinion this way!
        used_x = x  # NOTE: Set this way for consistency.
        x_samples_ddim = sample_model(input_im, models['turncam'], sampler, precision, h, w,
                                      ddim_steps, n_samples, scale, ddim_eta, used_x, y, z)

        output_ims = []
        for x_sample in x_samples_ddim:
            x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            output_ims.append(Image.fromarray(x_sample.astype(np.uint8)))

        description = None

        if 'angles' in return_what:
            return (x, y, z, description, new_fig, show_in_im2, output_ims)
        else:
            return (description, new_fig, show_in_im2, output_ims)


def calc_cam_cone_pts_3d(polar_deg, azimuth_deg, radius_m, fov_deg):
    '''
    :param polar_deg (float).
    :param azimuth_deg (float).
    :param radius_m (float).
    :param fov_deg (float).
    :return (5, 3) array of float with (x, y, z).
    '''
    polar_rad = np.deg2rad(polar_deg)
    azimuth_rad = np.deg2rad(azimuth_deg)
    fov_rad = np.deg2rad(fov_deg)
    polar_rad = -polar_rad  # NOTE: Inverse of how used_x relates to x.

    # Camera pose center:
    cam_x = radius_m * np.cos(azimuth_rad) * np.cos(polar_rad)
    cam_y = radius_m * np.sin(azimuth_rad) * np.cos(polar_rad)
    cam_z = radius_m * np.sin(polar_rad)

    # Obtain four corners of camera frustum, assuming it is looking at origin.
    # First, obtain camera extrinsics (rotation matrix only):
    camera_R = np.array([[np.cos(azimuth_rad) * np.cos(polar_rad),
                          -np.sin(azimuth_rad),
                          -np.cos(azimuth_rad) * np.sin(polar_rad)],
                         [np.sin(azimuth_rad) * np.cos(polar_rad),
                          np.cos(azimuth_rad),
                          -np.sin(azimuth_rad) * np.sin(polar_rad)],
                         [np.sin(polar_rad),
                          0.0,
                          np.cos(polar_rad)]])
    # print('camera_R:', lo(camera_R).v)

    # Multiply the frustum corners (defined in camera space) by the rotation to bring them into world space:
    corn1 = [-1.0, np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
    corn2 = [-1.0, -np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
    corn3 = [-1.0, -np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
    corn4 = [-1.0, np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
    corn1 = np.dot(camera_R, corn1)
    corn2 = np.dot(camera_R, corn2)
    corn3 = np.dot(camera_R, corn3)
    corn4 = np.dot(camera_R, corn4)

    # Now attach as offset to actual 3D camera position:
    corn1 = np.array(corn1) / np.linalg.norm(corn1, ord=2)
    corn_x1 = cam_x + corn1[0]
    corn_y1 = cam_y + corn1[1]
    corn_z1 = cam_z + corn1[2]
    corn2 = np.array(corn2) / np.linalg.norm(corn2, ord=2)
    corn_x2 = cam_x + corn2[0]
    corn_y2 = cam_y + corn2[1]
    corn_z2 = cam_z + corn2[2]
    corn3 = np.array(corn3) / np.linalg.norm(corn3, ord=2)
    corn_x3 = cam_x + corn3[0]
    corn_y3 = cam_y + corn3[1]
    corn_z3 = cam_z + corn3[2]
    corn4 = np.array(corn4) / np.linalg.norm(corn4, ord=2)
    corn_x4 = cam_x + corn4[0]
    corn_y4 = cam_y + corn4[1]
    corn_z4 = cam_z + corn4[2]

    xs = [cam_x, corn_x1, corn_x2, corn_x3, corn_x4]
    ys = [cam_y, corn_y1, corn_y2, corn_y3, corn_y4]
    zs = [cam_z, corn_z1, corn_z2, corn_z3, corn_z4]

    return np.array([xs, ys, zs]).T
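

# Illustrative, hypothetical call (kept as a comment): the first row of the result is the camera
# apex, and each remaining row is a unit-length frustum-corner direction added to that apex.
#   cone = calc_cam_cone_pts_3d(polar_deg=0.0, azimuth_deg=0.0, radius_m=2.5, fov_deg=50.0)
#   cone.shape  # (5, 3); cone[0] is [2.5, 0.0, 0.0], i.e. the camera center looking at the origin.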


def run_demo(
        device_idx=_GPU_INDEX,
        ckpt='105000.ckpt',
        config='configs/sd-objaverse-finetune-c_concat-256.yaml'):

    print('sys.argv:', sys.argv)
    if len(sys.argv) > 1:
        print('old device_idx:', device_idx)
        device_idx = int(sys.argv[1])
        print('new device_idx:', device_idx)

    device = f'cuda:{device_idx}'
    config = OmegaConf.load(config)

    # Instantiate all models beforehand for efficiency.
    models = dict()
    print('Instantiating LatentDiffusion...')
    models['turncam'] = load_model_from_config(config, ckpt, device=device)
    print('Instantiating Carvekit HiInterface...')
    models['carvekit'] = create_carvekit_interface()
    print('Instantiating StableDiffusionSafetyChecker...')
    models['nsfw'] = StableDiffusionSafetyChecker.from_pretrained(
        'CompVis/stable-diffusion-safety-checker').to(device)
    print('Instantiating AutoFeatureExtractor...')
    models['clip_fe'] = AutoFeatureExtractor.from_pretrained(
        'CompVis/stable-diffusion-safety-checker')

    # Reduce NSFW false positives.
    # NOTE: At the time of writing, and for diffusers 0.12.1, the default parameters are:
    # models['nsfw'].concept_embeds_weights:
    # [0.1800, 0.1900, 0.2060, 0.2100, 0.1950, 0.1900, 0.1940, 0.1900, 0.1900, 0.2200, 0.1900,
    #  0.1900, 0.1950, 0.1984, 0.2100, 0.2140, 0.2000].
    # models['nsfw'].special_care_embeds_weights:
    # [0.1950, 0.2000, 0.2200].
    # We multiply all by some factor > 1 to make them less likely to be triggered.
    models['nsfw'].concept_embeds_weights *= 1.07
    models['nsfw'].special_care_embeds_weights *= 1.07

    with open('instructions.md', 'r') as f:
        article = f.read()

    # NOTE: Examples must match inputs
    # [polar_slider, azimuth_slider, radius_slider, image_block,
    #  preprocess_chk, scale_slider, samples_slider, steps_slider].
    example_fns = ['1_blue_arm.png', '2_cybercar.png', '3_sushi.png', '4_blackarm.png',
                   '5_cybercar.png', '6_burger.png', '7_london.png', '8_motor.png']
    num_examples = len(example_fns)
    example_fps = [os.path.join(os.path.dirname(__file__), 'configs', x) for x in example_fns]
    example_angles = [(-40.0, -65.0, 0.0), (-30.0, 90.0, 0.0), (45.0, -15.0, 0.0),
                      (-75.0, 100.0, 0.0), (-40.0, -75.0, 0.0), (-45.0, 0.0, 0.0),
                      (-55.0, 90.0, 0.0), (-20.0, 125.0, 0.0)]
    examples_full = [[*example_angles[i], example_fps[i], True, 3, 4, 50]
                     for i in range(num_examples)]
    print('examples_full:', examples_full)

    # Compose demo layout & data flow.
    demo = gr.Blocks(title=_TITLE)

    with demo:
        gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=0.9, variant='panel'):

                image_block = gr.Image(type='pil', image_mode='RGBA',
                                       label='Input image of single object')
                preprocess_chk = gr.Checkbox(
                    True, label='Preprocess image automatically (remove background and recenter object)')
                # info='If enabled, the uploaded image will be preprocessed to remove the background and recenter the object by cropping and/or padding as necessary. '
                # 'If disabled, the image will be used as-is, *BUT* a fully transparent or white background is required.'),

                gr.Markdown('*Try camera position presets:*')
                with gr.Row():
                    left_btn = gr.Button('View from the Left', variant='primary')
                    above_btn = gr.Button('View from Above', variant='primary')
                    right_btn = gr.Button('View from the Right', variant='primary')
                with gr.Row():
                    random_btn = gr.Button('Random Rotation', variant='primary')
                    below_btn = gr.Button('View from Below', variant='primary')
                    behind_btn = gr.Button('View from Behind', variant='primary')

                gr.Markdown('*Control camera position manually:*')
                polar_slider = gr.Slider(
                    -90, 90, value=0, step=5, label='Polar angle (vertical rotation in degrees)')
                # info='Positive values move the camera down, while negative values move the camera up.')
                azimuth_slider = gr.Slider(
                    -180, 180, value=0, step=5, label='Azimuth angle (horizontal rotation in degrees)')
                # info='Positive values move the camera right, while negative values move the camera left.')
                radius_slider = gr.Slider(
                    -0.5, 0.5, value=0.0, step=0.1, label='Zoom (relative distance from center)')
                # info='Positive values move the camera further away, while negative values move the camera closer.')

                samples_slider = gr.Slider(1, 8, value=4, step=1,
                                           label='Number of samples to generate')

                with gr.Accordion('Advanced options', open=False):
                    scale_slider = gr.Slider(0, 30, value=3, step=1,
                                             label='Diffusion guidance scale')
                    steps_slider = gr.Slider(5, 200, value=75, step=5,
                                             label='Number of diffusion inference steps')

                with gr.Row():
                    vis_btn = gr.Button('Visualize Angles', variant='secondary')
                    run_btn = gr.Button('Run Generation', variant='primary')

                desc_output = gr.Markdown(
                    'The results will appear on the right.', visible=_SHOW_DESC)

            with gr.Column(scale=1.1, variant='panel'):

                vis_output = gr.Plot(
                    label='Relationship between input (green) and output (blue) camera poses')

                gen_output = gr.Gallery(label='Generated images from specified new viewpoint')
                gen_output.style(grid=2)

                preproc_output = gr.Image(type='pil', image_mode='RGB',
                                          label='Preprocessed input image',
                                          visible=_SHOW_INTERMEDIATE)

        cam_vis = CameraVisualizer(vis_output)

        gr.Examples(
            examples=examples_full,  # NOTE: elements must match inputs list!
            fn=partial(main_run, models, device, cam_vis, 'gen'),
            inputs=[polar_slider, azimuth_slider, radius_slider, image_block,
                    preprocess_chk, scale_slider, samples_slider, steps_slider],
            outputs=[desc_output, vis_output, preproc_output, gen_output],
            cache_examples=True,
            run_on_click=True,
        )

        gr.Markdown(article)

        # NOTE: I am forced to update vis_output for these preset buttons, because otherwise the
        # gradio plot always resets the plotly 3D viewpoint for some reason, which might confuse
        # the user into thinking that the plot has been updated too.

        # polar_slider.change(fn=partial(main_run, models, device, cam_vis, 'vis'),
        #                     inputs=[polar_slider, azimuth_slider, radius_slider,
        #                             image_block, preprocess_chk],
        #                     outputs=[desc_output, vis_output, preproc_output],
        #                     queue=False)
        # azimuth_slider.change(fn=partial(main_run, models, device, cam_vis, 'vis'),
        #                       inputs=[polar_slider, azimuth_slider, radius_slider,
        #                               image_block, preprocess_chk],
        #                       outputs=[desc_output, vis_output, preproc_output],
        #                       queue=False)
        # radius_slider.change(fn=partial(main_run, models, device, cam_vis, 'vis'),
        #                      inputs=[polar_slider, azimuth_slider, radius_slider,
        #                              image_block, preprocess_chk],
        #                      outputs=[desc_output, vis_output, preproc_output],
        #                      queue=False)

        vis_btn.click(fn=partial(main_run, models, device, cam_vis, 'vis'),
                      inputs=[polar_slider, azimuth_slider, radius_slider,
                              image_block, preprocess_chk],
                      outputs=[desc_output, vis_output, preproc_output],
                      queue=False)

        run_btn.click(fn=partial(main_run, models, device, cam_vis, 'gen'),
                      inputs=[polar_slider, azimuth_slider, radius_slider, image_block,
                              preprocess_chk, scale_slider, samples_slider, steps_slider],
                      outputs=[desc_output, vis_output, preproc_output, gen_output])

        # NEW:
        preset_inputs = [image_block, preprocess_chk,
                         scale_slider, samples_slider, steps_slider]
        preset_outputs = [polar_slider, azimuth_slider, radius_slider,
                          desc_output, vis_output, preproc_output, gen_output]
        left_btn.click(fn=partial(main_run, models, device, cam_vis,
                                  'angles_gen', 0.0, -90.0, 0.0),
                       inputs=preset_inputs, outputs=preset_outputs)
        above_btn.click(fn=partial(main_run, models, device, cam_vis,
                                   'angles_gen', -90.0, 0.0, 0.0),
                        inputs=preset_inputs, outputs=preset_outputs)
        right_btn.click(fn=partial(main_run, models, device, cam_vis,
                                   'angles_gen', 0.0, 90.0, 0.0),
                        inputs=preset_inputs, outputs=preset_outputs)
        random_btn.click(fn=partial(main_run, models, device, cam_vis,
                                    'rand_angles_gen', -1.0, -1.0, -1.0),
                         inputs=preset_inputs, outputs=preset_outputs)
        below_btn.click(fn=partial(main_run, models, device, cam_vis,
                                   'angles_gen', 90.0, 0.0, 0.0),
                        inputs=preset_inputs, outputs=preset_outputs)
        behind_btn.click(fn=partial(main_run, models, device, cam_vis,
                                    'angles_gen', 0.0, 180.0, 0.0),
                         inputs=preset_inputs, outputs=preset_outputs)

    demo.launch(enable_queue=True)


if __name__ == '__main__':
    fire.Fire(run_demo)