# --- Web page header captured together with the file (not program code) ---
# OzzyGT's picture
# OzzyGT HF staff
# reverse images
# 444a999
import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL, ControlNetModel, TCDScheduler
from gradio_imageslider import ImageSlider
from image_gen_aux import LineArtPreprocessor
from PIL import Image, ImageEnhance
from controlnet_union import ControlNetModel_Union
from pipeline_sdxl_recolor import StableDiffusionXLRecolorPipeline
# Line-art extractor loaded from the "OzzyGT/lineart" checkpoint and moved to
# the GPU. recolor_image() calls it to derive the second control image.
lineart_preprocessor = LineArtPreprocessor.from_pretrained("OzzyGT/lineart").to("cuda")
# Two ControlNets, both requested as float16 / "fp16" variant weights.
# NOTE(review): the list order presumably lines up position-by-position with
# the control images and per-ControlNet scales passed in recolor_image()
# (source image first, line art second) — confirm against the pipeline.
controlnet = [
    ControlNetModel.from_pretrained(
        "OzzyGT/ControlNet-recolorXL", torch_dtype=torch.float16, variant="fp16"
    ),
    ControlNetModel_Union.from_pretrained(
        "OzzyGT/controlnet-union-promax-sdxl-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ),
]
# Separate VAE checkpoint loaded in float16 and moved to the GPU; it is handed
# to the pipeline below in place of the VAE bundled with the base model.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")
# Assemble the recolor pipeline from the base checkpoint, plugging in the VAE
# and the ControlNet list created above, then move everything to the GPU.
pipe = StableDiffusionXLRecolorPipeline.from_pretrained(
    "recoilme/ColorfulXL-Lightning",
    torch_dtype=torch.float16,
    vae=vae,
    controlnet=controlnet,
    variant="fp16",
).to("cuda")
# Replace the checkpoint's scheduler with a TCDScheduler built from the
# existing scheduler's config.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
# Attach the IP-Adapter weights; recolor_image() supplies the uploaded photo
# as `ip_adapter_image`.
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl_vit-h.safetensors",
    image_encoder_folder="models/image_encoder",
)
# Per-layer IP-Adapter scale: only the three entries of up-block 0 are listed,
# with the middle one set to 0.0.
# NOTE(review): layers not listed here presumably fall back to the library's
# default scale — confirm in the diffusers IP-Adapter documentation.
scale = {
    "up": {"block_0": [1.0, 0.0, 1.0]},
}
pipe.set_ip_adapter_scale(scale)
# The prompts are fixed, so they are encoded once at import time and the
# resulting embeddings are reused by every recolor_image() call.
prompt = "high quality color photo, sharp, detailed, 4k, colorized, remastered"
negative_prompt = "blurry, low resolution, bad quality, pixelated, black and white, b&w, grayscale, monochrome, sepia"
# NOTE(review): the arguments are passed positionally, which relies on the
# custom pipeline's encode_prompt taking (prompt, negative_prompt, device,
# do_classifier_free_guidance) in that order. Stock diffusers SDXL pipelines
# order their parameters differently — confirm against pipeline_sdxl_recolor.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(prompt, negative_prompt, "cuda", True)
def _blend_color_overlay(base_rgba, generated_image, *, saturation=4.0, opacity=0.20):
    """Composite a saturated, mostly transparent copy of the generation onto the base.

    Args:
        base_rgba: The original photo, already converted to RGBA.
        generated_image: The pipeline's final PIL image.
        saturation: Factor given to ``ImageEnhance.Color`` for the overlay.
        opacity: Multiplier applied to the overlay's alpha channel.

    Returns:
        The RGBA result of ``Image.alpha_composite(base_rgba, overlay)``.
    """
    overlay = generated_image.convert("RGBA")
    overlay = ImageEnhance.Color(overlay).enhance(saturation)
    # Scale the alpha channel down so the original photo stays dominant and
    # the generation mostly contributes colour.
    faded_alpha = overlay.split()[3].point(lambda p: p * opacity)
    overlay.putalpha(faded_alpha)
    return Image.alpha_composite(base_rgba, overlay)


@spaces.GPU(duration=16)
def recolor_image(image):
    """Colorize the uploaded photo, streaming results to the UI.

    Args:
        image: The ``gr.ImageEditor`` value; the uploaded picture is read from
            its ``"background"`` entry.

    Yields:
        ``(source_image, generated_image)`` for every image the pipeline
        yields, followed by one final ``(source_rgba, merged_image)`` pair in
        which the last generation is blended over the original photo.

    Raises:
        gr.Error: If no image was uploaded, or the pipeline yielded nothing.
    """
    source_image = image["background"] if image else None
    if source_image is None:
        # Without this guard a missing upload surfaces as an opaque failure
        # further down instead of a readable message in the UI.
        raise gr.Error("Please upload an image before generating.")

    lineart_image = lineart_preprocessor(source_image, resolution_scale=0.7)[0]

    # Use a dedicated name for the loop variable rather than reusing the
    # `image` parameter, so "nothing was generated" stays detectable.
    generated_image = None
    for generated_image in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=[source_image, lineart_image],
        ip_adapter_image=source_image,
        num_inference_steps=8,
        guidance_scale=2.0,
        controlnet_conditioning_scale=[1.0, 0.5],
        control_guidance_end=[1.0, 0.9],
    ):
        yield source_image, generated_image

    if generated_image is None:
        raise gr.Error("The pipeline did not produce an image.")

    source_rgba = source_image.convert("RGBA")
    merged_image = _blend_color_overlay(source_rgba, generated_image)
    yield source_rgba, merged_image
def clear_result():
    """Build the Gradio update that resets the output component's value to None."""
    blank_update = gr.update(value=None)
    return blank_update
# Extra CSS passed to gr.Blocks: forces the Gradio container to a fixed width.
css = """
.gradio-container {
width: 1024px !important;
}
"""
# HTML header shown at the top of the page through gr.HTML.
title = """<h1 align="center">Diffusers Image Recolor</h1>
<div align="center">Upload a grayscale image to colorize it.</div>
<div align="center">This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-recolor'>Recoloring photos with diffusers</a>.</div>
"""
# Page layout and event wiring.
# NOTE(review): the file had lost its indentation, so the nesting below is
# reconstructed (editor and slider side by side in one row, click handler
# inside the Blocks context) — confirm against the deployed app.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    run_button = gr.Button("Generate")
    with gr.Row():
        # Upload-only editor with layers, eraser and brush switched off.
        # recolor_image() reads the uploaded picture from the value's
        # "background" entry.
        input_image = gr.ImageEditor(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            eraser=False,
            brush=False,
            sources=["upload"],
            image_mode="RGB",
        )
        # Read-only slider that receives the image pairs yielded by
        # recolor_image().
        result = ImageSlider(interactive=False, label="Generated Image", type="pil")
    # Clear the previous result first, then run the recolor generator.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=recolor_image,
        inputs=[input_image],
        outputs=result,
    )
demo.launch(share=False)