# Hugging Face Space — Running on Zero (ZeroGPU hardware).
import gradio as gr | |
import spaces | |
import torch | |
from diffusers import AutoencoderKL, ControlNetModel, TCDScheduler | |
from gradio_imageslider import ImageSlider | |
from image_gen_aux import LineArtPreprocessor | |
from PIL import Image, ImageEnhance | |
from controlnet_union import ControlNetModel_Union | |
from pipeline_sdxl_recolor import StableDiffusionXLRecolorPipeline | |
def _load_fp16(model_cls, repo_id, **kwargs):
    """Return ``model_cls.from_pretrained(repo_id, ...)`` with float16 weights.

    Any extra keyword arguments (``variant``, sub-models, ...) are forwarded
    unchanged to ``from_pretrained``.
    """
    return model_cls.from_pretrained(repo_id, torch_dtype=torch.float16, **kwargs)


# Line-art extractor; recolor_image feeds its output to the second ControlNet.
lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")

# Two ControlNets, paired by index with recolor_image's
# ``image=[source_image, lineart_image]``:
#   [0] recolor ControlNet   <- source photo
#   [1] ControlNet-Union     <- line art
controlnet = [
    _load_fp16(ControlNetModel, "OzzyGT/ControlNet-recolorXL", variant="fp16"),
    _load_fp16(
        ControlNetModel_Union,
        "OzzyGT/controlnet-union-promax-sdxl-1.0",
        variant="fp16",
    ),
]

# Standalone VAE handed to the pipeline in place of the checkpoint's own.
vae = _load_fp16(AutoencoderKL, "madebyollin/sdxl-vae-fp16-fix").to("cuda")

pipe = _load_fp16(
    StableDiffusionXLRecolorPipeline,
    "recoilme/ColorfulXL-Lightning",
    vae=vae,
    controlnet=controlnet,
    variant="fp16",
).to("cuda")

# Replace the checkpoint's scheduler with TCD, reusing its configuration.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
# Per-block / per-layer IP-Adapter weights (diffusers dict form): the three
# attention layers of up-block 0 get 1.0, 0.0 and 1.0. Blocks not listed are
# presumably left at scale 0 by diffusers — confirm against the installed version.
scale = {"up": {"block_0": [1.0, 0.0, 1.0]}}

# Attach the SDXL "vit-h" IP-Adapter weights plus the image encoder stored in
# models/image_encoder; recolor_image passes the source photo as
# ``ip_adapter_image``.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl_vit-h.safetensors",
    image_encoder_folder="models/image_encoder",
)
pipe.set_ip_adapter_scale(scale)
# Fixed prompts shared by every request.
prompt = "high quality color photo, sharp, detailed, 4k, colorized, remastered"
negative_prompt = (
    "blurry, low resolution, bad quality, pixelated, "
    "black and white, b&w, grayscale, monochrome, sepia"
)

# The prompts never change, so they are encoded once at start-up and the
# embeddings are reused by recolor_image. The positional arguments follow the
# custom StableDiffusionXLRecolorPipeline.encode_prompt signature (presumably
# prompt, negative prompt, device, classifier-free-guidance flag — confirm;
# the stock SDXL pipeline orders its parameters differently).
_encoded_prompts = pipe.encode_prompt(prompt, negative_prompt, "cuda", True)
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = _encoded_prompts
del _encoded_prompts
def _blend_colorized(background, generated):
    """Tint *background* with a strongly saturated copy of *generated*.

    Args:
        background: RGBA source image used as the composite's base layer.
        generated: Image produced by the recolor pipeline.

    Returns:
        A new RGBA image. *generated* is converted to RGBA, its colour
        saturation is boosted 4x, and its alpha is scaled to 20%. It is then
        alpha-composited over *background*.
    """
    overlay = ImageEnhance.Color(generated.convert("RGBA")).enhance(4.0)
    # Image.alpha_composite requires both images to have the same size. If the
    # pipeline's output size differs from the upload (e.g. rounded to a model
    # stride — unverified for this custom pipeline), match the source first.
    if overlay.size != background.size:
        overlay = overlay.resize(background.size, Image.LANCZOS)
    # Scale the overlay's alpha to 20% so the composite stays 80% source.
    overlay.putalpha(overlay.getchannel("A").point(lambda p: p * 0.20))
    return Image.alpha_composite(background, overlay)


@spaces.GPU
def recolor_image(image):
    """Colorize the uploaded photo, streaming previews to the result slider.

    Args:
        image: Value of the ``gr.ImageEditor`` input (``type="pil"``). Only
            its ``"background"`` entry is read. ``None`` or a missing
            background (nothing uploaded) is reported to the user.

    Yields:
        ``(source, preview)`` pairs for the ``ImageSlider`` output. There is
        one pair per item yielded by the recolor pipeline. A final pair
        follows, whose second element is the saturated generation blended
        over the RGBA source.

    Raises:
        gr.Error: If no image was uploaded or the pipeline produced no image.
    """
    # Per the Hugging Face ZeroGPU docs, a GPU is only attached while a
    # ``@spaces.GPU``-decorated function runs; this Space runs on ZeroGPU.
    source_image = image.get("background") if image else None
    if source_image is None:
        raise gr.Error("Please upload an image first.")

    lineart_image = lineart_preprocessor(source_image, resolution_scale=0.7)[0]

    # The custom pipeline is iterated rather than called for a single result;
    # every item it yields is streamed to the slider as a live preview.
    generated = None
    for generated in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        # Paired by index with the two ControlNets: recolor, then line art.
        image=[source_image, lineart_image],
        ip_adapter_image=source_image,
        num_inference_steps=8,
        guidance_scale=2.0,
        controlnet_conditioning_scale=[1.0, 0.5],
        control_guidance_end=[1.0, 0.9],
    ):
        yield source_image, generated

    # Without this guard, an empty pipeline run would fall through with no
    # generated image to blend.
    if generated is None:
        raise gr.Error("The pipeline did not produce an image.")

    source_rgba = source_image.convert("RGBA")
    yield source_rgba, _blend_colorized(source_rgba, generated)
def clear_result():
    """Return a Gradio update that empties the bound output component.

    It is chained before ``recolor_image`` on the Generate button, so the
    previous result disappears as soon as a new run starts.
    """
    reset = gr.update(value=None)
    return reset
# Custom stylesheet passed to gr.Blocks: pins the app container to a fixed
# 1024px width (the same width as the editor's 1024x1024 canvas).
css = """
.gradio-container {
width: 1024px !important;
}
"""
# HTML header shown at the top of the page (rendered with gr.HTML).
title = """<h1 align="center">Diffusers Image Recolor</h1>
<div align="center">Upload a grayscale image to colorize it.</div>
<div align="center">This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-recolor'>Recoloring photos with diffusers</a>.</div>
"""
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    run_button = gr.Button("Generate")

    # Upload-only editor (no brush/eraser/layers) beside the before/after slider.
    with gr.Row():
        input_image = gr.ImageEditor(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            eraser=False,
            brush=False,
            sources=["upload"],
            image_mode="RGB",
        )
        result = ImageSlider(interactive=False, label="Generated Image", type="pil")

    # Two-step event: blank the slider first, then stream the recolored output.
    clear_event = run_button.click(fn=clear_result, inputs=None, outputs=result)
    clear_event.then(fn=recolor_image, inputs=[input_image], outputs=result)

demo.launch(share=False)