import spaces
import os
from typing import cast

import gradio as gr
from PIL import Image
import torch
import torchvision
from diffusers import DDIMScheduler

from load_image import load_exr_image, load_ldr_image
from pipeline_rgb2x import StableDiffusionAOVMatEstPipeline

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
current_directory = os.path.dirname(os.path.abspath(__file__))

# Load the RGB -> X material-estimation pipeline once at import time.
_pipe = StableDiffusionAOVMatEstPipeline.from_pretrained(
    "zheng95z/rgb-to-x",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
pipe = cast(StableDiffusionAOVMatEstPipeline, _pipe)
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)


@spaces.GPU
def generate(
    photo,
    seed: int,
    inference_step: int,
    num_samples: int,
) -> list[tuple[Image.Image, str]]:
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # gr.File may hand back a plain path string or a tempfile wrapper with a
    # .name attribute depending on the Gradio version; normalize to a path.
    photo_path = photo if isinstance(photo, str) else photo.name

    # Load the image as a CHW tensor on the GPU: EXR inputs are tonemapped,
    # LDR inputs are converted from sRGB to linear.
    if photo_path.endswith(".exr"):
        photo = load_exr_image(photo_path, tonemaping=True, clamp=True).to("cuda")
    elif photo_path.endswith((".png", ".jpg", ".jpeg")):
        photo = load_ldr_image(photo_path, from_srgb=True).to("cuda")
    else:
        raise gr.Error(f"Unsupported file type: {photo_path}")

    # Resize so the longer side is 1000 pixels (preserving the aspect ratio),
    # then round both dimensions down to multiples of 8 as the pipeline requires.
    old_height = photo.shape[1]
    old_width = photo.shape[2]
    ratio = old_height / old_width
    max_side = 1000
    if old_height > old_width:
        new_height = max_side
        new_width = int(new_height / ratio)
    else:
        new_width = max_side
        new_height = int(new_width * ratio)
    new_width = new_width // 8 * 8
    new_height = new_height // 8 * 8
    photo = torchvision.transforms.Resize((new_height, new_width))(photo)

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]
    prompts = {
        "albedo": "Albedo (diffuse basecolor)",
        "normal": "Camera-space Normal",
        "roughness": "Roughness",
        "metallic": "Metallicness",
        "irradiance": "Irradiance (diffuse lighting)",
    }

    return_list = []
    for i in range(num_samples):
        for aov_name in required_aovs:
            prompt = prompts[aov_name]
            generated_image = pipe(
                prompt=prompt,
                photo=photo,
                num_inference_steps=inference_step,
                height=new_height,
                width=new_width,
                generator=generator,
                required_aovs=[aov_name],
            ).images[0][0]  # type: ignore
            # Resize the prediction back to the original resolution and pair it
            # with a caption for the output gallery.
            generated_image = torchvision.transforms.Resize((old_height, old_width))(
                generated_image
            )
            return_list.append((generated_image, f"Generated {aov_name} {i}"))

    return return_list


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model RGB -> X (Realistic image -> Intrinsic channels)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given Image")
            photo = gr.File(label="Photo", file_types=[".exr", ".png", ".jpg"])

            gr.Markdown("### Parameters")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )

        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )

    # The example row supplies a value for every generate() argument so the
    # example can be cached eagerly; 50 steps and 1 sample match the slider
    # defaults, and the fixed seed keeps the cached output reproducible.
    examples = gr.Examples(
        examples=[
            [
                "rgb2x/example/Castlereagh_corridor_photo.png",
                0,
                50,
                1,
            ]
        ],
        inputs=[photo, seed, inference_step, num_samples],
        outputs=[result_gallery],
        fn=generate,
        cache_mode="eager",
        cache_examples=True,
    )

    run_button.click(
        fn=generate,
        inputs=[photo, seed, inference_step, num_samples],
        outputs=result_gallery,
        queue=True,
    )

if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)