| import gradio as gr |
| import spaces |
| import torch |
|
|
| from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, TCDScheduler |
|
|
|
|
def callback_cfg_cutoff(pipeline, step_index, timestep, callback_kwargs):
    """Step-end callback that switches guidance off part-way through sampling.

    When ``step_index`` equals ``int(pipeline.num_timesteps * 0.2)`` the
    callback keeps only the last entry of each batched input and sets
    ``pipeline._guidance_scale`` to ``0.0``. On every other step the kwargs
    are returned untouched.

    Args:
        pipeline: Object exposing ``num_timesteps`` and ``_guidance_scale``.
        step_index: Index of the step that just finished.
        timestep: Unused; present to match the callback signature.
        callback_kwargs: Mapping holding ``prompt_embeds``,
            ``add_text_embeds``, ``add_time_ids``, ``control_image`` and
            ``control_type``.

    Returns:
        The same ``callback_kwargs`` mapping, possibly updated in place.
    """
    cutoff_step = int(pipeline.num_timesteps * 0.2)
    if step_index != cutoff_step:
        return callback_kwargs

    # Stage every trimmed value first so nothing is written back (and the
    # guidance scale is not touched) unless all of the lookups succeed.
    # NOTE(review): the last entry is presumably the conditional half of a
    # classifier-free-guidance batch - confirm against the pipeline.
    trimmed = {}
    for key in ("prompt_embeds", "add_text_embeds", "add_time_ids"):
        trimmed[key] = callback_kwargs[key][-1:]

    # control_image is a list; only its first element is sliced, and the
    # list object itself is modified in place.
    control_images = callback_kwargs["control_image"]
    control_images[0] = control_images[0][-1:]
    trimmed["control_image"] = control_images

    trimmed["control_type"] = callback_kwargs["control_type"][-1:]

    pipeline._guidance_scale = 0.0
    callback_kwargs.update(trimmed)
    return callback_kwargs
|
|
|
|
# Display name -> Hugging Face repo id for the checkpoints listed in the UI.
MODELS = {
    "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}

# ControlNet weights, loaded from the fp16 variant and placed on the GPU.
controlnet_model = ControlNetUnionModel.from_pretrained(
    "OzzyGT/controlnet-union-promax-sdxl-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
controlnet_model.to(dtype=torch.float16, device="cuda")

# Replacement VAE loaded in half precision.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)
vae = vae.to("cuda")

# Base pipeline wired to the ControlNet and VAE above through the
# community pipeline named in ``custom_pipeline``.
pipe = DiffusionPipeline.from_pretrained(
    "SG161222/RealVisXL_V5.0_Lightning",
    custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
    controlnet=controlnet_model,
    vae=vae,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Swap in a TCD scheduler built from the pipeline's existing scheduler config.
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
|
|
|
|
@spaces.GPU(duration=24)
def fill_image(prompt, negative_prompt, image, model_selection, paste_back):
    """Inpaint the masked region of an image and yield a before/after pair.

    Args:
        prompt: Text describing what should appear in the masked area.
        negative_prompt: Text describing what should be avoided.
        image: Value of the mask editor, read as a mapping with a
            ``"background"`` PIL image and a ``"layers"`` list whose first
            entry is the drawn mask (its 4th channel is used as alpha).
        model_selection: Accepted so the signature matches the UI inputs;
            not used, since only one pipeline is loaded at module level.
        paste_back: When true, generated pixels are pasted only inside the
            mask so the rest of the image keeps its original pixels; when
            false the raw pipeline output is returned.

    Yields:
        A ``(source, result)`` tuple of images.

    Raises:
        gr.Error: If no input image or no mask layer was supplied.
    """
    # Fail early with a readable message instead of an opaque
    # TypeError/AttributeError/IndexError further down.
    if not image or image.get("background") is None or not image.get("layers"):
        raise gr.Error("Please provide an input image and draw a mask before generating.")

    source = image["background"]
    mask = image["layers"][0]

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt, device="cuda", negative_prompt=negative_prompt)

    # Any pixel with non-zero alpha in the drawn layer counts as masked.
    alpha_channel = mask.split()[3]
    binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)

    # Control image: a copy of the source with the masked area painted with 0.
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

    # NOTE(review): control_mode 7 is presumably the ControlNet's
    # inpaint/repaint mode - confirm against the model card.
    generated = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        control_image=[cnet_image],
        controlnet_conditioning_scale=[1.0],
        control_mode=[7],
        num_inference_steps=8,
        guidance_scale=1.5,
        callback_on_step_end=callback_cfg_cutoff,
        callback_on_step_end_tensor_inputs=[
            "prompt_embeds",
            "add_text_embeds",
            "add_time_ids",
            "control_image",
            "control_type",
        ],
    ).images[0]

    if paste_back:
        # Only the masked pixels are taken from the generated image.
        generated = generated.convert("RGBA")
        cnet_image.paste(generated, (0, 0), binary_mask)
        result_image = cnet_image
    else:
        result_image = generated

    yield source, result_image
|
|
|
|
def clear_result():
    """Return a Gradio update that sets the target component's value to None."""
    cleared = gr.update(value=None)
    return cleared
|
|
|
|
# HTML banner passed to gr.HTML at the top of the demo page.
title = (
    '<h2 align="center">Diffusers Fast Inpaint</h2>\n'
    '<div align="center">Draw the mask over the subject you want to erase or '
    "change and write what you want to inpaint it with.</div>\n"
)
|
|
# UI layout and event wiring for the inpainting demo.
with gr.Blocks() as demo:
    gr.HTML(title)

    # Prompt and negative prompt side by side.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                lines=1,
            )
        with gr.Column():
            with gr.Row():
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    lines=1,
                )

    with gr.Row():
        with gr.Column():
            run_button = gr.Button("Generate")

        with gr.Column():
            paste_back = gr.Checkbox(True, label="Paste back original")

    # Mask editor on the left, before/after slider on the right.
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            canvas_size=(1024, 1024),
            layers=False,
            height=512,
        )

        result = gr.ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    # Hidden until a generation has finished successfully.
    use_as_input_button = gr.Button("Use as Input Image", visible=False)

    model_selection = gr.Dropdown(choices=list(MODELS.keys()), value="RealVisXL V5.0 Lightning", label="Model")

    def use_output_as_input(output_image):
        """Feed the second image of the slider pair back into the editor."""
        return gr.update(value=output_image[1])

    use_as_input_button.click(fn=use_output_as_input, inputs=[result], outputs=[input_image])

    # The Generate button and pressing Enter in the prompt box run the same
    # chain: clear the old result, hide the reuse button, generate, then
    # reveal the reuse button.
    for trigger in (run_button.click, prompt.submit):
        trigger(
            fn=clear_result,
            inputs=None,
            outputs=result,
        ).then(
            fn=lambda: gr.update(visible=False),
            inputs=None,
            outputs=use_as_input_button,
        ).then(
            fn=fill_image,
            inputs=[prompt, negative_prompt, input_image, model_selection, paste_back],
            outputs=result,
        ).success(
            # .success (not .then) so the button is only revealed when the
            # generation actually produced a result; otherwise clicking it
            # would index into an empty slider value.
            fn=lambda: gr.update(visible=True),
            inputs=None,
            outputs=use_as_input_button,
        )


demo.queue(max_size=12).launch(share=False)
|
|