import sys
import os

import spaces
import gradio as gr
import torch
import numpy as np
import rembg
from torchvision import transforms
from pytorch_lightning import seed_everything
from loguru import logger

from inference import InferenceModel

_SAMPLE_TAB_ID_ = 0
_HIGHRES_TAB_ID_ = 1
_FOREGROUND_TAB_ID_ = 2


def set_loggers(level):
    """Reset loguru to a single stderr sink at the given level."""
    logger.remove()
    logger.add(sys.stderr, level=level)


def on_guide_select(evt: gr.SelectData):
    """Use the clicked low-res sample as the guiding image for high-res generation."""
    logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
    return [evt.value["image"]["path"], f"Sample {evt.index}"]


def on_input_select(evt: gr.SelectData):
    logger.debug(f"You selected {evt.value} at {evt.index} from {evt.target}")
    return evt.value["image"]["path"]


@spaces.GPU(duration=120)
def sample_fine(
    input_im,
    domain="Albedo",
    require_mask=False,
    steps=25,
    n_samples=4,  # accepted for signature parity with the low-res options; high-res always generates one sample
    seed=0,
    guid_img=None,
    vert_split=2,
    hor_split=2,
    overlaps=2,
    guidance_scale=2,
):
    """Generate a single high-resolution sample, tiled into overlapping patches and guided by a low-res sample."""
    if require_mask:
        input_im = remove_bg(input_im)
    seed_everything(int(seed))
    model = model_dict[domain]
    inp = tform(input_im).to(device).permute(1, 2, 0)  # HWC, RGBA
    guid_img = tform(guid_img).to(device).permute(1, 2, 0)
    images = model.generation(
        (vert_split, hor_split),
        overlaps,
        guid_img[..., :3],
        inp[..., :3],   # RGB channels
        inp[..., 3:],   # alpha mask
        dps_scale=guidance_scale,
        uc_score=1.0,
        ddim_steps=steps,
        batch_size=1,
        n_samples=1,
    )
    # keep the guidance image alongside the outputs
    images["guid_images"] = [(guid_img.detach().cpu().numpy() * 255).astype(np.uint8)]
    output = images["out_images"][0]
    return [[(output, "High-res")], gr.Tabs(selected=_HIGHRES_TAB_ID_)]


def remove_bg(input_im):
    """Segment the foreground with rembg, returning an RGBA image."""
    return rembg.remove(input_im, session=model_dict["remove_bg"])


@spaces.GPU()
def sampling(input_im, domain="Albedo", require_mask=False, steps=25, n_samples=4, seed=0):
    """Generate n_samples low-resolution intrinsic images for the given domain."""
    seed_everything(int(seed))
    model = model_dict[domain]
    if require_mask:
        input_im = remove_bg(input_im)
    inp = tform(input_im).to(device).permute(1, 2, 0)  # HWC, RGBA
    images = model.generation(
        (1, 1),
        1,
        None,
        inp[..., :3],
        inp[..., 3:],
        dps_scale=0,
        uc_score=1,
        ddim_steps=steps,
        batch_size=1,
        n_samples=n_samples,
    )
    # dict keys follow InferenceModel.generation's return values
    return [
        [(images["input_image"][0], "Foreground Object"), (images["input_maskes"][0], "Foreground Mask")],
        [(img, f"Sample {idx}") for idx, img in enumerate(images["out_images"])],
        gr.Tabs(selected=_SAMPLE_TAB_ID_),
    ]


title = "IntrinsicAnything: Learning Diffusion Priors for Inverse Rendering Under Unknown Illumination"
description = """
#### Generate intrinsic images (Albedo, Specular Shading) from a single image.
##### Tips
- Check the "Auto Mask" box if the input image needs a foreground mask, or supply your own mask via an RGBA input.
- If the input image is high resolution, you can optionally generate a high-resolution sample. The original image is split into `Vertical Splits` x `Horizontal Splits` patches with some `Overlaps` in between. Due to the computation constraints of the online demo, we recommend keeping `Vertical Splits` x `Horizontal Splits` at 6 or fewer and setting `Overlaps` to 2. Set the denoising steps to at least 80 for high-resolution samples.
""" set_loggers("INFO") device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Loading Models...") model_dict = { "Albedo": InferenceModel(ckpt_path="weights/albedo", use_ddim=True, gpu_id=0), "Specular": InferenceModel(ckpt_path="weights/specular", use_ddim=True, gpu_id=0), "remove_bg": rembg.new_session(), } logger.info(f"All models Loaded!") tform = transforms.Compose([ transforms.ToTensor() ]) examples_dir = "examples" examples = [[os.path.join(examples_dir, img_name)] for img_name in os.listdir(examples_dir)] # theme definition theme = gr.Theme.from_hub("NoCrypt/miku") theme.body_background_fill = "#FFFFFF " theme.body_background_fill_dark = "#000000" demo = gr.Blocks(title=title, theme=theme) with demo: with gr.Row(): with gr.Column(scale=1): gr.Markdown('# ' + title) gr.Markdown(description) with gr.Column(): with gr.Row(): with gr.Column(scale=0.8): image_input = [gr.Image(image_mode='RGBA', height=256)] with gr.Column(): with gr.Tabs(): with gr.TabItem("Options"): with gr.Column(): with gr.Row(): domain_box = gr.Radio([("Albedo", "Albedo"),("Specular", "Specular")], value="Albedo", label="Type") with gr.Column(): gr.Markdown("### Automatic foreground segmentation") mask_box = gr.Checkbox(False, label="Auto Mask") options_tab = [ domain_box, mask_box, gr.Slider(5, 200, value=50, step=5, label="Denoising Steps (The larger the better results)"), gr.Slider(1, 10, value=2, step=1, label="Number of Samples"), gr.Number(75424, label="Seed", precision=0), ] with gr.TabItem("Advanced (High-res)"): with gr.Column(): guiding_img = gr.Image(image_mode='RGBA', label="Guiding Image", interactive=False, height=256, visible=False) sample_idx = gr.Textbox(placeholder="Select one from the generate low-res samples", lines=1, interactive=False, label="Guiding Image") options_advanced_tab = [ # high resolution options guiding_img, gr.Slider(1, 4, value=2, step=1, label="Vertical Splits"), gr.Slider(1, 4, value=2, step=1, label="Horizontal Splits"), gr.Slider(1, 5, value=2, step=1, label="Overlaps"), gr.Slider(0, 5, value=3, step=1, label="Guidance Scale"),] with gr.Column(scale=1.0): with gr.Tabs() as res_tabs: with gr.TabItem("Generated Samples", id=_SAMPLE_TAB_ID_): image_output = gr.Gallery(label="Generated Samples", object_fit="contain", columns=[2], rows=[2],height=512, selected_index=0) with gr.TabItem("High Resolution Sample", id=_HIGHRES_TAB_ID_): image_output_high = gr.Gallery(label="High Resolution Sample", object_fit="contain", columns=[1], rows=[1],height=512, selected_index=0) with gr.TabItem("Foreground Object", id=_FOREGROUND_TAB_ID_): forground_output = gr.Gallery(label="Foreground Object", object_fit="contain", columns=[2], rows=[1],height=512, selected_index=0) with gr.Row(): generate_button = gr.Button("Generate") generate_button_fine = gr.Button("Generate High-Res") examples_gr = gr.Examples(examples=examples, inputs=image_input, cache_examples=False, examples_per_page=30, label='Examples (Click one to start!)') with gr.Row(): pass # forground_output = gr.Gallery(label="Inputs", preview=False, columns=[2], rows=[1],height=512, selected_index=0) # image_output = gr.Gallery(label="Generated Samples", object_fit="cover", columns=[1], rows=[6],height=512, selected_index=0) # image_output_high = gr.Gallery(label="High Resolution Sample", object_fit="cover", columns=[1], rows=[1],height=512, selected_index=0) generate_button.click(sampling, inputs=image_input+options_tab, outputs=[forground_output, image_output, res_tabs]) generate_button_fine.click(sample_fine, 
        inputs=image_input + options_tab + options_advanced_tab,
        outputs=[image_output_high, res_tabs],
    )
    # clicking a generated low-res sample selects it as the guiding image
    image_output.select(on_guide_select, None, [guiding_img, sample_idx])

logger.info("Demo initialized, starting...")
demo.queue().launch()
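
# Usage (a sketch; the filename app.py is an assumption, not fixed by this
# script): place the checkpoints under weights/albedo and weights/specular and
# input images under examples/ (the paths loaded above), then run:
#
#   python app.py
#
# launch() serves on http://127.0.0.1:7860 by default; pass share=True to
# launch() for a temporary public Gradio link.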