Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import spaces | |
| import torch | |
| from diffusers import AutoencoderKL, ControlNetModel, TCDScheduler | |
| from gradio_imageslider import ImageSlider | |
| from image_gen_aux import LineArtPreprocessor | |
| from PIL import Image, ImageEnhance | |
| from controlnet_union import ControlNetModel_Union | |
| from pipeline_sdxl_recolor import StableDiffusionXLRecolorPipeline | |
# Line-art extractor; recolor_image feeds its output to the second ControlNet.
lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")

# Shared loading options for checkpoints fetched as fp16 weight variants.
_FP16 = {"torch_dtype": torch.float16, "variant": "fp16"}

# Order matters: recolor_image passes its control images, conditioning scales
# and guidance-end values as lists in this same order ([source image, line art]).
controlnet = [
    ControlNetModel.from_pretrained("OzzyGT/ControlNet-recolorXL", **_FP16),
    ControlNetModel_Union.from_pretrained(
        "OzzyGT/controlnet-union-promax-sdxl-1.0", **_FP16
    ),
]

# The VAE is loaded with the fp16 dtype only (no weight variant requested).
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
vae = vae.to("cuda")

pipe = StableDiffusionXLRecolorPipeline.from_pretrained(
    "recoilme/ColorfulXL-Lightning",
    vae=vae,
    controlnet=controlnet,
    **_FP16,
).to("cuda")

# Replace the checkpoint's scheduler with TCD, reusing its configuration.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
# Per-layer IP-Adapter scales: only up-block 0 is listed, and its three entries
# enable attention layers 0 and 2 while disabling layer 1 (the layer the
# diffusers docs use as the "style" layer).
# NOTE(review): blocks not listed here are presumably given scale 0 by
# diffusers' set_ip_adapter_scale — confirm against the installed version.
scale = {"up": {"block_0": [1.0, 0.0, 1.0]}}

# SDXL IP-Adapter (ViT-H variant) plus its image encoder from the h94 repo.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl_vit-h.safetensors",
    image_encoder_folder="models/image_encoder",
)
pipe.set_ip_adapter_scale(scale)
# Fixed text conditioning shared by every request.
prompt = "high quality color photo, sharp, detailed, 4k, colorized, remastered"
negative_prompt = "blurry, low resolution, bad quality, pixelated, black and white, b&w, grayscale, monochrome, sepia"

# Encode the prompts once at start-up; recolor_image passes these precomputed
# embeddings to the pipeline instead of raw prompt strings.
# NOTE(review): the positional arguments rely on the custom pipeline's
# encode_prompt signature (presumably prompt, negative_prompt, device,
# do_classifier_free_guidance) — confirm in pipeline_sdxl_recolor.py.
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
    pipe.encode_prompt(prompt, negative_prompt, "cuda", True)
)
def _blend_colors(source_image, generated_image, saturation=4.0, opacity=0.20):
    """Overlay a saturated, translucent copy of *generated_image* on *source_image*.

    Args:
        source_image: The uploaded PIL image.
        generated_image: The last PIL image yielded by the pipeline.
        saturation: Factor passed to ``ImageEnhance.Color.enhance``.
        opacity: Multiplier applied to every value of the overlay's alpha channel.

    Returns:
        The composited image in RGBA mode.
    """
    base = source_image.convert("RGBA")
    overlay = generated_image.convert("RGBA")
    # Image.alpha_composite requires both images to have the same size.
    # NOTE(review): presumably the pipeline can return a size different from
    # the upload (e.g. dimensions rounded to a multiple of 8); resizing the
    # overlay keeps the composite from raising in that case.
    if overlay.size != base.size:
        overlay = overlay.resize(base.size, Image.LANCZOS)
    overlay = ImageEnhance.Color(overlay).enhance(saturation)
    alpha = overlay.split()[3].point(lambda p: p * opacity)
    overlay.putalpha(alpha)
    return Image.alpha_composite(base, overlay)


# `spaces` is imported for this decorator: on ZeroGPU hardware GPU work must
# run inside a @spaces.GPU function; outside ZeroGPU the decorator has no effect.
@spaces.GPU
def recolor_image(image):
    """Colorize the uploaded image, streaming results to the output slider.

    Args:
        image: Value of the gr.ImageEditor; its "background" entry holds the
            uploaded PIL image.

    Yields:
        (source_image, generated) pairs. First, every image the pipeline yields.
        Then a final pair whose second element is the saturated generated
        colors blended onto the source.

    Raises:
        gr.Error: If no image was uploaded or the pipeline yields no image.
    """
    if image is None or image.get("background") is None:
        raise gr.Error("Please upload an image first.")
    source_image = image["background"]
    lineart_image = lineart_preprocessor(source_image, resolution_scale=0.7)[0]

    generated = None
    for generated in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        # One control image per ControlNet, in the order they were loaded.
        image=[source_image, lineart_image],
        ip_adapter_image=source_image,
        num_inference_steps=8,
        guidance_scale=2.0,
        controlnet_conditioning_scale=[1.0, 0.5],
        control_guidance_end=[1.0, 0.9],
    ):
        # Forward every image the pipeline yields to the slider as it arrives.
        yield source_image, generated

    if generated is None:
        raise gr.Error("The pipeline did not produce an image.")
    yield source_image, _blend_colors(source_image, generated)
def clear_result():
    """Return a Gradio update that empties the output ImageSlider.

    Runs as the first step of the Generate click, before recolor_image.
    """
    empty_value = None
    return gr.update(value=empty_value)
# Page CSS: pin the Gradio container to a fixed 1024px width.
css = (
    "\n"
    ".gradio-container {\n"
    "    width: 1024px !important;\n"
    "}\n"
)

# HTML header rendered at the top of the page.
title = (
    '<h1 align="center">Diffusers Image Recolor</h1>\n'
    '<div align="center">Upload a grayscale image to colorize it.</div>\n'
    '<div align="center">This space is a PoC made for the guide '
    "<a href='https://huggingface.co/blog/OzzyGT/diffusers-recolor'>"
    "Recoloring photos with diffusers</a>.</div>\n"
)
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    run_button = gr.Button("Generate")

    with gr.Row():
        # Upload-only editor: cropping to 1024x1024 is allowed, drawing is not.
        input_image = gr.ImageEditor(
            type="pil",
            label="Input Image",
            sources=["upload"],
            image_mode="RGB",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            eraser=False,
            brush=False,
        )
        result = ImageSlider(label="Generated Image", type="pil", interactive=False)

    # Step 1 empties the slider; step 2 streams the recolored images into it.
    clear_step = run_button.click(fn=clear_result, inputs=None, outputs=result)
    clear_step.then(fn=recolor_image, inputs=[input_image], outputs=result)

demo.launch(share=False)