Spaces:
Running
Running
from typing import Tuple | |
from PIL import Image | |
import torch | |
import spaces | |
import numpy as np | |
import random | |
import gradio as gr | |
import os | |
from diffusers.pipelines.flux.pipeline_flux_controlnet_inpaint import ( | |
FluxControlNetInpaintPipeline, | |
) | |
from diffusers.pipelines.flux.pipeline_flux_inpaint import FluxInpaintPipeline | |
from diffusers.models.controlnet_flux import FluxControlNetModel | |
from controlnet_aux import CannyDetector | |
HF_TOKEN = os.getenv("HF_TOKEN")
MAX_SEED = np.iinfo(np.int32).max
IMAGE_SIZE = 1024

dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

base_model = "black-forest-labs/FLUX.1-dev"
controlnet_model = "InstantX/FLUX.1-dev-Controlnet-Canny"

# The token (None when HF_TOKEN is unset, which keeps the library default) is
# passed explicitly so access-restricted Hub repositories can be downloaded.
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model, torch_dtype=dtype, token=HF_TOKEN
)
pipe = FluxControlNetInpaintPipeline.from_pretrained(
    base_model, controlnet=controlnet, torch_dtype=dtype, token=HF_TOKEN
)
# Non-ControlNet alternative:
# pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype).to(device)

if device == "cuda":
    # Model CPU offload manages device placement itself, moving each component
    # to the GPU only while it runs. Calling `.to("cuda")` first would put the
    # entire pipeline on the GPU at once and defeat the memory saving.
    pipe.enable_model_cpu_offload()
else:
    # enable_model_cpu_offload() targets an accelerator (CUDA by default);
    # without one, keep the whole pipeline on the CPU.
    pipe.to(device)

# Edge detector used to build the ControlNet conditioning image.
canny = CannyDetector()
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE,
    multiple_of: int = 32,
) -> Tuple[int, int]:
    """Scale ``(width, height)`` so the longer side becomes ``maximum_dimension``.

    The aspect ratio is preserved up to rounding. Small images are scaled up
    and large ones scaled down. Each side is then rounded down to a multiple
    of ``multiple_of``.

    Args:
        original_resolution_wh: Source size as ``(width, height)`` in pixels.
        maximum_dimension: Target length of the longer side.
        multiple_of: Both returned sides are multiples of this value.

    Returns:
        ``(new_width, new_height)``. Each side is at least ``multiple_of``, so
        an extreme aspect ratio can never round a side down to zero.

    Raises:
        ValueError: If either input side is not positive.
    """
    width, height = original_resolution_wh
    if width <= 0 or height <= 0:
        raise ValueError(f"Image dimensions must be positive, got {width}x{height}.")

    longest_side = max(width, height)

    def _scale_and_snap(side: int) -> int:
        # Floor-divide the exact product instead of multiplying by a float
        # ratio: ``side * (maximum_dimension / longest_side)`` can land just
        # below an integer (e.g. 1023.999...) and then lose a whole multiple.
        scaled = int(side * maximum_dimension // longest_side)
        snapped = scaled - scaled % multiple_of
        # A very thin image would otherwise snap its short side to 0.
        return max(multiple_of, snapped)

    return _scale_and_snap(width), _scale_and_snap(height)
# def get_system_memory(): | |
# memory = psutil.virtual_memory() | |
# memory_percent = memory.percent | |
# memory_used = memory.used / (1024.0**3) | |
# memory_total = memory.total / (1024.0**3) | |
# return { | |
# "percent": f"{memory_percent}%", | |
# "used": f"{memory_used:.3f}GB", | |
# "total": f"{memory_total:.3f}GB", | |
# } | |
# | |
# spaces.GPU requests a GPU on ZeroGPU Spaces and has no effect elsewhere.
@spaces.GPU
def inpaint(
    image,
    mask,
    prompt,
    strength,
    num_inference_steps,
    guidance_scale,
    controlnet_conditioning_scale,
):
    """Inpaint the masked region of ``image`` guided by ``prompt`` and Canny edges.

    Args:
        image: Input PIL image.
        mask: PIL mask image; it is resized to the same working size as ``image``.
        prompt: Text prompt passed to the pipeline.
        strength: Inpainting strength passed to the pipeline; must be > 0.
        num_inference_steps: Number of denoising steps (coerced to ``int``).
        guidance_scale: Guidance scale passed to the pipeline.
        controlnet_conditioning_scale: Weight of the Canny ControlNet signal.

    Returns:
        Tuple of (generated image, Canny control image, resized input image,
        resized mask), matching the four output components of the UI.

    Raises:
        gr.Error: If the image or mask is missing, or ``strength`` is not
            positive. Gradio shows the message to the user.
    """
    if image is None or mask is None:
        raise gr.Error("Please upload both an input image and a mask image.")
    if strength <= 0:
        # strength 0 leaves the pipeline with zero denoising steps.
        raise gr.Error("Strength must be greater than 0.")

    generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))

    width, height = resize_image_dimensions(original_resolution_wh=image.size)
    resized_image = image.resize((width, height), Image.LANCZOS)
    # NEAREST keeps the mask hard-edged; LANCZOS introduces intermediate grey
    # values and ringing along the mask boundary.
    resized_mask = mask.resize((width, height), Image.NEAREST)
    # Build the edge map from the image that is actually inpainted, so the
    # control signal is aligned with it. The detector may return its own
    # resolution, so snap the result to the working size.
    canny_image = canny(resized_image).resize((width, height), Image.LANCZOS)

    image_res = pipe(
        prompt,
        image=resized_image,
        control_image=canny_image,
        mask_image=resized_mask,
        # Without explicit height/width the pipeline uses its own default
        # resolution, and the aspect-preserving resize above is discarded.
        height=height,
        width=width,
        strength=strength,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        generator=generator,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
    ).images[0]
    return image_res, canny_image, resized_image, resized_mask
# Gradio UI: the four outputs mirror the tuple returned by `inpaint`.
iface = gr.Interface(
    fn=inpaint,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Image(type="pil", label="Mask Image"),
        gr.Textbox(label="Prompt"),
        # Strength 0 gives the pipeline zero denoising steps, which it
        # rejects, so the slider starts just above 0.
        gr.Slider(0.01, 1, value=0.85, label="Strength"),
        gr.Slider(1, 50, value=30, step=1, label="Number of Inference Steps"),
        gr.Slider(1, 20, value=7.0, label="Guidance Scale"),
        gr.Slider(0, 1, value=0.7, label="Controlnet conditioning"),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Canny Image"),
        gr.Image(type="pil", label="Image"),
        gr.Image(type="pil", label="Mask"),
    ],
    title="FLUX.1 Inpainting",
    description="Inpainting using the FLUX.1 model. Upload an image and a mask, then provide a prompt to guide the inpainting process.",
)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    iface.launch()
# with gr.Blocks() as demo: | |
# # gr.LoginButton() | |
# # with gr.Row(): | |
# # with gr.Column(): | |
# # gr.Textbox(value="Hello Memory") | |
# # with gr.Column(): | |
# # gr.JSON(get_system_memory, every=1) | |
# gr.Interface( | |
# fn=inpaint, | |
# inputs=[ | |
# gr.ImageEditor( | |
# label="Image", | |
# type="pil", | |
# sources=["upload", "webcam"], | |
# image_mode="RGB", | |
# layers=False, | |
# brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"), | |
# ), | |
# gr.Textbox(label="Prompt"), | |
# gr.Slider(0, 1, value=0.95, label="Strength"), | |
# gr.Slider(1, 100, value=50, step=1, label="Number of Inference Steps"), | |
# gr.Slider(0, 20, value=5, label="Guidance Scale"), | |
# # gr.Slider(0, 1, value=0.5, label="ControlNet Conditioning Scale"), | |
# ], | |
# outputs=gr.Image(type="pil", label="Output Image"), | |
# title="Flux Inpaint AI Model", | |
# description="Upload an image and a mask, then provide a prompt to generate an inpainted image.", | |
# ) | |
# | |
# demo.launch(height=800) | |