Spaces:
Paused
Paused
| from functools import partial | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
| import torch | |
| import os | |
| import fire | |
| from omegaconf import OmegaConf | |
| from SyncDreamer.ldm.models.diffusion.sync_dreamer import SyncDDIMSampler, SyncMultiviewDiffusion | |
| from SyncDreamer.ldm.util import add_margin, instantiate_from_config | |
| from sam_utils import sam_init, sam_out_nosave | |
| from SyncDreamer.ldm.util import instantiate_from_config, prepare_inputs | |
| import argparse | |
| import cv2 | |
| from transformers import pipeline | |
| from diffusers.utils import load_image, make_image_grid | |
| from diffusers import UniPCMultistepScheduler | |
| from pipeline_controlnet_sync import StableDiffusionControlNetPipeline | |
| from controlnet_sync import ControlNetModelSync | |
| _TITLE = '''ControlNet + SyncDreamer''' | |
| _DESCRIPTION = ''' | |
| Given a single-view image and select a target azimuth, ControlNet + SyncDreamer is able to generate the target view. | |
| This HF app is modified from [SyncDreamer HF app](https://huggingface.co/spaces/liuyuan-pal/SyncDreamer). The difference is that I added ControlNet on top of SyncDreamer. | |
| In addition, the elevations of both input and output images are assumed to be 30 degrees. | |
| ''' | |
| _USER_GUIDE0 = "Step1: Please upload an image in the block above (or choose an example shown in the left)." | |
| _USER_GUIDE2 = "Step2: Please choose a **Target azimuth** and click **Run Generation**. The **Target azimuth** is the azimuth of the output image relative to the input image in clockwise. This costs about 45s." | |
| _USER_GUIDE3 = "Generated output image of the target view is shown below! (You may adjust the **Crop size** and **Target azimuth** to get another result!)" | |
| others = '''**Step 1**. Select "Crop size" and click "Crop it". ==> The foreground object is centered and resized. </br>''' | |
| deployed = True | |
| if deployed: | |
| print(f"Is CUDA available: {torch.cuda.is_available()}") | |
| print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}") | |
class BackgroundRemoval:
    """Callable wrapper around carvekit's ``HiInterface`` that strips the background.

    carvekit is imported inside ``__init__``, so it is only needed once an
    instance is actually created.
    """

    def __init__(self, device='cuda'):
        from carvekit.api.high import HiInterface

        settings = dict(
            object_type="object",  # Can be "object" or "hairs-like".
            batch_size_seg=5,
            batch_size_matting=1,
            device=device,
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=True,
        )
        self.interface = HiInterface(**settings)

    def __call__(self, image):
        """Process a single image and return the first (only) result."""
        results = self.interface([image])
        return results[0]
def resize_inputs(image_input, crop_size):
    """Crop the foreground of an RGBA image, rescale it and pad it for the model.

    The tight bounding box of all pixels with non-zero alpha is cropped out and
    rescaled by ``crop_size / max(h, w)`` (aspect ratio kept, so the longer
    side becomes ~``crop_size`` pixels). The result is passed to
    ``add_margin(..., size=256)``, which presumably pads it onto a 256x256
    canvas (confirm against ``SyncDreamer.ldm.util.add_margin``).

    Args:
        image_input: PIL image with an alpha channel (RGBA), or None.
        crop_size: target length, in pixels, of the longer side of the crop.

    Returns:
        Whatever ``add_margin`` returns, or None when ``image_input`` is None.

    Raises:
        ValueError: if the alpha channel is entirely zero (no foreground).
    """
    if image_input is None:
        return None
    alpha_np = np.asarray(image_input)[:, :, 3]
    ys, xs = np.nonzero(alpha_np)
    if xs.size == 0:
        raise ValueError('resize_inputs: the input image has no foreground (alpha channel is all zero).')
    min_x, max_x = int(xs.min()), int(xs.max())
    min_y, max_y = int(ys.min()), int(ys.max())
    # PIL's crop box is (left, upper, right, lower) with right/lower EXCLUSIVE:
    # +1 keeps the last foreground column/row and prevents a 0x0 crop (and the
    # division by zero below) for a single-pixel foreground.
    ref_img_ = image_input.crop((min_x, min_y, max_x + 1, max_y + 1))
    h, w = ref_img_.height, ref_img_.width
    scale = crop_size / max(h, w)
    # Clamp to 1 px so very thin objects never request a zero-size resize.
    h_, w_ = max(1, int(scale * h)), max(1, int(scale * w))
    ref_img_ = ref_img_.resize((w_, h_), resample=Image.Resampling.BICUBIC)
    return add_margin(ref_img_, size=256)
def generate(pipe, image_input, azimuth):
    """Run the pipeline and return the generated view nearest to ``azimuth``.

    Args:
        pipe: callable invoked as ``pipe(conditioning_image=image_input)``.
            Its result is indexed by view; this function assumes one view every
            22.5 degrees of azimuth (16 views). TODO confirm against
            ``pipeline_controlnet_sync``.
        image_input: conditioning image passed straight through to ``pipe``.
        azimuth: target azimuth in degrees, relative to the input view
            (any real value; it is wrapped into [0, 360)).

    Returns:
        The element of the pipeline output closest to the requested azimuth.
    """
    step = 22.5
    num_views = round(360 / step)  # 16 views around the object
    # The final modulo wraps azimuths just below 360 (which round up to index
    # 16) back to view 0 instead of indexing past the end of the output.
    target_index = round(azimuth % 360 / step) % num_views
    output = pipe(conditioning_image=image_input)
    return output[target_index]
def sam_predict(predictor, removal, raw_im):
    """Segment the foreground of an uploaded image and return it as RGBA.

    When ``deployed`` is False the input image is returned unchanged.
    Otherwise the steps are:

    1. Downscale the image in place to fit in 512x512.
    2. Remove the background with ``removal``.
    3. Use the bounding box of the resulting alpha channel as the box prompt
       for ``sam_out_nosave``.

    The SAM output's colour is composited onto white, and its mask is kept as
    the alpha channel.

    Args:
        predictor: SAM predictor (from ``sam_init()`` when deployed).
        removal: background-removal callable, e.g. a ``BackgroundRemoval``.
        raw_im: PIL image from the upload widget, or None.

    Returns:
        An RGBA PIL image, ``raw_im`` itself when not deployed, or None when
        there is no input.
    """
    if raw_im is None:
        return None
    if not deployed:
        return raw_im

    raw_im.thumbnail([512, 512], Image.Resampling.LANCZOS)
    image_nobg = removal(raw_im.convert('RGB'))
    # Resize BEFORE measuring the box so the box coordinates are in the same
    # frame as the image that is handed to SAM below.
    image_nobg.thumbnail([512, 512], Image.Resampling.LANCZOS)

    alpha = np.asarray(image_nobg)[:, :, -1]
    cols = np.nonzero(alpha.sum(axis=0))[0]
    rows = np.nonzero(alpha.sum(axis=1))[0]
    if cols.size == 0:
        # Background removal kept nothing: prompt SAM with the whole image
        # rather than crashing on min()/max() of an empty array.
        box = (0, 0, image_nobg.width - 1, image_nobg.height - 1)
    else:
        box = (int(cols.min()), int(rows.min()), int(cols.max()), int(rows.max()))

    image_sam = sam_out_nosave(predictor, image_nobg.convert("RGB"), box)
    image_sam = np.asarray(image_sam, np.float32) / 255
    out_mask = image_sam[:, :, 3:]
    # Alpha-composite onto white: rgb * a + 1 * (1 - a).
    out_rgb = image_sam[:, :, :3] * out_mask + 1 - out_mask
    out_img = (np.concatenate([out_rgb, out_mask], 2) * 255).astype(np.uint8)
    image_sam = Image.fromarray(out_img, mode='RGBA')
    torch.cuda.empty_cache()
    return image_sam
def load_model(cfg, ckpt, strict=True):
    """Build a model from an OmegaConf config, load its weights and move it to the GPU.

    Args:
        cfg: path to the OmegaConf YAML file; its ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a torch checkpoint containing a ``'state_dict'`` entry.
        strict: forwarded to ``load_state_dict``.

    Returns:
        The model on CUDA, in eval mode.
    """
    config = OmegaConf.load(cfg)
    model = instantiate_from_config(config.model)
    print(f'loading model from {ckpt} ...')
    # Deserialize on the CPU: the model is moved to the GPU with .cuda() below
    # anyway, so mapping the checkpoint straight to 'cuda' only adds a second,
    # temporary full copy of the weights in GPU memory.
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=strict)
    del checkpoint  # drop the CPU copy before moving the model to the GPU
    return model.cuda().eval()
def run_demo():
    """Load the models (when ``deployed``), build the Gradio UI and launch it.

    Event wiring:
      - upload   -> sam_predict -> resize_inputs -> step-2 guide text
      - crop size change -> resize_inputs -> step-2 guide text
      - "Run generation" -> generate (seeded from "Random seed") -> step-3 guide text
    """
    if deployed:
        controlnet = ControlNetModelSync.from_pretrained('controlnet_ckpt', torch_dtype=torch.float32, use_safetensors=True)
        cfg = 'SyncDreamer/configs/syncdreamer.yaml'
        dreamer = load_model(cfg, 'SyncDreamer/ckpt/syncdreamer-pretrain.ckpt', strict=True)
        controlnet.to('cuda', dtype=torch.float32)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            controlnet=controlnet, dreamer=dreamer, torch_dtype=torch.float32, use_safetensors=True
        )
        pipe.to('cuda', dtype=torch.float32)
        mask_predictor = sam_init()
        removal = BackgroundRemoval()
    else:
        # UI-only mode: nothing is loaded.
        mask_predictor = None
        removal = None
        controlnet = None
        dreamer = None
        pipe = None

    # NOTE: Examples must match inputs
    examples_full = [
        ['hf_demo/examples/fox.png', 200],
        ['hf_demo/examples/monkey.png', 200],
        ['hf_demo/examples/cat.png', 200],
        ['hf_demo/examples/crab.png', 200],
        ['hf_demo/examples/elephant.png', 200],
        ['hf_demo/examples/flower.png', 200],
        ['hf_demo/examples/forest.png', 200],
        ['hf_demo/examples/teapot.png', 200],
        ['hf_demo/examples/basket.png', 200],
    ]

    # Created up-front (unrendered) so gr.Examples can reference them; they are
    # placed in the layout later via .render().
    image_block = gr.Image(type='pil', image_mode='RGBA', height=256, label='Input image', tool=None, interactive=True)
    azimuth = gr.Slider(0, 360, 90, step=22.5, label='Target azimuth', interactive=True)
    crop_size = gr.Slider(120, 240, 200, step=10, label='Crop size', interactive=True)

    def update_guide(guide_text_value):
        """Return a Gradio update that replaces the guide Markdown."""
        return gr.update(value=guide_text_value)

    def update_guide2(text, im):
        """Show ``text`` once an input image exists, otherwise the step-1 guide."""
        if im is None:
            return _USER_GUIDE0
        return text

    def generate_seeded(image, target_azimuth, seed_value):
        """Seed torch's global RNG from the UI, then run ``generate``."""
        # torch.manual_seed seeds the CPU and all CUDA generators, so the
        # "Random seed" option actually affects sampling -- assuming the
        # pipeline draws its noise from the global generator (TODO confirm in
        # pipeline_controlnet_sync). A cleared Number box yields None: skip.
        if seed_value is not None:
            torch.manual_seed(int(seed_value))
        return generate(pipe, image, target_azimuth)

    # Compose demo layout & data flow.
    with gr.Blocks(title=_TITLE, css="hf_demo/style.css") as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
        gr.Markdown(_DESCRIPTION)

        with gr.Row(variant='panel'):
            with gr.Column(scale=1.2):
                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block, crop_size],
                    outputs=[image_block, crop_size],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=5,
                )

            with gr.Column(scale=0.8):
                image_block.render()
                guide_text = gr.Markdown(_USER_GUIDE0, visible=True)
                fig0 = gr.Image(value=Image.open('assets/crop_size.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)

            with gr.Column(scale=0.8):
                sam_block = gr.Image(type='pil', image_mode='RGBA', label="SAM output", height=256, interactive=False)
                crop_size.render()
                fig1 = gr.Image(value=Image.open('assets/azimuth.jpg'), type='pil', image_mode='RGB', height=256, show_label=False, tool=None, interactive=False)

            with gr.Column(scale=0.8):
                input_block = gr.Image(type='pil', image_mode='RGBA', label="Input to ControlNet + SyncDreamer", height=256, interactive=False)
                azimuth.render()
                with gr.Accordion('Advanced options', open=False):
                    seed = gr.Number(6033, label='Random seed', interactive=True)
                run_btn = gr.Button('Run generation', variant='primary', interactive=True)

        output_block = gr.Image(type='pil', image_mode='RGB', label="Output of ControlNet + SyncDreamer", height=256, interactive=False)

        image_block.clear(fn=partial(update_guide, _USER_GUIDE0), outputs=[guide_text], queue=False)
        # Parenthesized chains instead of backslash continuations: a trailing
        # "\" on the last link would glue the next statement onto this one.
        (
            image_block.change(fn=partial(sam_predict, mask_predictor, removal), inputs=[image_block], outputs=[sam_block], queue=True)
            .success(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=True)
            .success(fn=partial(update_guide2, _USER_GUIDE2), inputs=[image_block], outputs=[guide_text], queue=False)
        )
        (
            crop_size.change(fn=resize_inputs, inputs=[sam_block, crop_size], outputs=[input_block], queue=True)
            .success(fn=partial(update_guide, _USER_GUIDE2), outputs=[guide_text], queue=False)
        )
        (
            run_btn.click(generate_seeded, inputs=[input_block, azimuth, seed], outputs=[output_block], queue=True)
            .success(fn=partial(update_guide, _USER_GUIDE3), outputs=[guide_text], queue=False)
        )

    demo.queue().launch(share=False, max_threads=80)  # auth=("admin", os.environ['PASSWD'])
if __name__ == "__main__":
    # Script entry point: python-fire exposes run_demo (which takes no
    # parameters) as the command-line interface, so running the file launches the demo.
    fire.Fire(component=run_demo)