"""Gradio demo: render a realistic RGB image from intrinsic channels (X -> RGB).

The user uploads any subset of albedo / normal / roughness / metallic /
irradiance maps plus an optional text prompt; the diffusion pipeline generates
one or more images that are shown in a gallery together with the inputs.
"""

import spaces

import os
from typing import cast

# Set the OpenEXR opt-in flag before the project image loaders are imported.
# NOTE(review): presumably `load_image` imports OpenCV, which reads this flag;
# setting it ahead of that import is harmless if it does not.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

import gradio as gr  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from PIL import Image  # noqa: E402,F401
from diffusers import DDIMScheduler  # noqa: E402

from load_image import load_exr_image, load_ldr_image  # noqa: E402
from pipeline_x2rgb import StableDiffusionAOVDropoutPipeline  # noqa: E402

current_directory = os.path.dirname(os.path.abspath(__file__))

# Extensions understood by `generate`; shared with the upload widgets so the
# UI never offers (or hides) a type the loader treats differently.
EXR_EXTENSIONS = (".exr",)
LDR_EXTENSIONS = (".png", ".jpg", ".jpeg")
SUPPORTED_FILE_TYPES = [*EXR_EXTENSIONS, *LDR_EXTENSIONS]

# Load the pretrained pipeline once at import time and move it to the GPU.
_pipe = StableDiffusionAOVDropoutPipeline.from_pretrained(
    "zheng95z/x-to-rgb",
    torch_dtype=torch.float16,
    cache_dir=os.path.join(current_directory, "model_cache"),
).to("cuda")
pipe = cast(StableDiffusionAOVDropoutPipeline, _pipe)

# Replace the default scheduler with DDIM, reusing the existing config.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)


@spaces.GPU
def generate(
    albedo,
    normal,
    roughness,
    metallic,
    irradiance,
    prompt: str,
    seed: int,
    inference_step: int,
    num_samples: int,
    guidance_scale: float,
    image_guidance_scale: float,
) -> list:
    """Run the X -> RGB pipeline and build the gallery contents.

    Args:
        albedo, normal, roughness, metallic, irradiance: Uploaded file objects
            exposing a ``.name`` path, or ``None`` when a channel is omitted.
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed passed to ``torch.Generator.manual_seed``.
        inference_step: Number of denoising steps.
        num_samples: How many images to generate (sequentially, sharing one
            generator so each sample differs).
        guidance_scale: Text guidance scale forwarded to the pipeline.
        image_guidance_scale: Image guidance scale forwarded to the pipeline.

    Returns:
        A list for ``gr.Gallery``: one ``(image, caption)`` tuple per generated
        sample, followed by five entries for the intrinsic channels in the
        order albedo, normal, roughness, metallic, irradiance. A provided
        channel is an ``(image, caption)`` tuple; a missing one is an
        uncaptioned all-zero ``(height, width, 3)`` array.
    """
    # Slider values may arrive as floats; the consumers below require ints.
    seed = int(seed)
    inference_step = int(inference_step)
    num_samples = int(num_samples)

    generator = torch.Generator(device="cuda").manual_seed(seed)

    def process_image(file, **kwargs):
        """Load one uploaded channel onto the GPU, or return None if unusable."""
        if file is None:
            return None
        # Compare case-insensitively so e.g. "map.EXR" is not silently ignored.
        lowered_name = file.name.lower()
        if lowered_name.endswith(EXR_EXTENSIONS):
            return load_exr_image(file.name, **kwargs).to("cuda")
        if lowered_name.endswith(LDR_EXTENSIONS):
            return load_ldr_image(file.name, **kwargs).to("cuda")
        return None

    # Load and process each intrinsic channel image.
    albedo_image = process_image(albedo, clamp=True)
    normal_image = process_image(normal, normalize=True)
    roughness_image = process_image(roughness, clamp=True)
    metallic_image = process_image(metallic, clamp=True)
    irradiance_image = process_image(irradiance, tonemaping=True, clamp=True)

    channels = [
        ("Albedo", albedo_image),
        ("Normal", normal_image),
        ("Roughness", roughness_image),
        ("Metallic", metallic_image),
        ("Irradiance", irradiance_image),
    ]

    # Output size follows the first provided channel, else a 768x768 default.
    # Assumes the loaders return channel-first (C, H, W) tensors, consistent
    # with the transpose(1, 2, 0) used for display below.
    height, width = 768, 768
    first_image = next((img for _, img in channels if img is not None), None)
    if first_image is not None:
        height, width = first_image.shape[1], first_image.shape[2]

    required_aovs = ["albedo", "normal", "roughness", "metallic", "irradiance"]

    return_list = []
    for i in range(num_samples):
        generated_image = pipe(
            prompt=prompt,
            albedo=albedo_image,
            normal=normal_image,
            roughness=roughness_image,
            metallic=metallic_image,
            irradiance=irradiance_image,
            num_inference_steps=inference_step,
            height=height,
            width=width,
            generator=generator,
            required_aovs=required_aovs,
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
            guidance_rescale=0.7,
            output_type="np",
        ).images[0]  # type: ignore
        return_list.append((generated_image, f"Generated Image {i}"))

    def post_process_image(img, label="Image"):
        """Convert a channel tensor to a captioned HWC array for the gallery."""
        if img is None:
            # Missing channel: keep its gallery slot with a blank placeholder.
            return np.zeros((height, width, 3))
        return (img.cpu().numpy().transpose(1, 2, 0), label)

    # Append the input channels to the output gallery.
    return_list.extend(
        post_process_image(img, label=label) for label, img in channels
    )

    return return_list


with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Model X -> RGB (Intrinsic channels -> realistic image)")
    with gr.Row():
        # Input side
        with gr.Column():
            gr.Markdown("### Given intrinsic channels")
            albedo = gr.File(label="Albedo", file_types=SUPPORTED_FILE_TYPES)
            normal = gr.File(label="Normal", file_types=SUPPORTED_FILE_TYPES)
            roughness = gr.File(label="Roughness", file_types=SUPPORTED_FILE_TYPES)
            metallic = gr.File(label="Metallic", file_types=SUPPORTED_FILE_TYPES)
            irradiance = gr.File(label="Irradiance", file_types=SUPPORTED_FILE_TYPES)

            gr.Markdown("### Parameters")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            with gr.Accordion("Advanced options", open=False):
                seed = gr.Slider(
                    label="Seed",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
                inference_step = gr.Slider(
                    label="Inference Step",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                num_samples = gr.Slider(
                    label="Samples",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )
                image_guidance_scale = gr.Slider(
                    label="Image Guidance Scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=1.5,
                )

        # Output side
        with gr.Column():
            gr.Markdown("### Output Gallery")
            result_gallery = gr.Gallery(
                label="Output",
                show_label=False,
                elem_id="gallery",
                columns=2,
            )

    run_button.click(
        fn=generate,
        inputs=[
            albedo,
            normal,
            roughness,
            metallic,
            irradiance,
            prompt,
            seed,
            inference_step,
            num_samples,
            guidance_scale,
            image_guidance_scale,
        ],
        outputs=result_gallery,
        queue=True,
    )


if __name__ == "__main__":
    demo.launch(debug=False, share=False, show_api=False)