|
from typing import Tuple
|
|
|
|
import requests
|
|
import random
|
|
import numpy as np
|
|
import gradio as gr
|
|
import spaces
|
|
import torch
|
|
from PIL import Image
|
|
from diffusers import FluxInpaintPipeline
|
|
from huggingface_hub import login
|
|
import os
|
|
import time
|
|
from gradio_imageslider import ImageSlider
|
|
|
|
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL
|
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
|
from transformers import CLIPTextModel, CLIPTokenizer,T5EncoderModel, T5TokenizerFast
|
|
import requests
|
|
from io import BytesIO
|
|
import PIL.Image
|
|
import requests
|
|
|
|
# Heading rendered at the top of the Gradio page.
MARKDOWN = """
# FLUX.1 Inpainting with lora
"""

# Upper bound of the seed slider and of randomly drawn seeds (int32 maximum).
MAX_SEED = np.iinfo(np.int32).max

# Default target size, in pixels, for the longer side of the resized input.
IMAGE_SIZE = 1024

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Optional Hugging Face access token taken from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")

# Only authenticate when a token was actually provided.
if HF_TOKEN:
    login(token=HF_TOKEN)

# Repository id the inpainting pipeline is loaded from.
bfl_repo = "camenduru/FLUX.1-dev-diffusers"
|
|
|
|
class calculateDuration:
    """Context manager that prints when a block started, ended and how long it took.

    Usage::

        with calculateDuration("load lora"):
            ...

    Attributes set while the manager is used:
        start_time / end_time: ``time.time()`` values taken on enter / exit.
        start_time_formatted / end_time_formatted: the same instants rendered
            as local time with ``_TIMESTAMP_FORMAT``.
        elapsed_time: ``end_time - start_time`` in seconds.

    Exceptions raised inside the ``with`` body are not suppressed.
    """

    _TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

    def __init__(self, activity_name=""):
        # Label used in the printed messages; may be empty.
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self.start_time_formatted = time.strftime(
            self._TIMESTAMP_FORMAT, time.localtime(self.start_time)
        )
        print(f"Activity: {self.activity_name}, Start time: {self.start_time_formatted}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        self.end_time_formatted = time.strftime(
            self._TIMESTAMP_FORMAT, time.localtime(self.end_time)
        )

        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")

        # Report the END timestamp here; the previous code printed the start
        # timestamp a second time.
        print(f"Activity: {self.activity_name}, End time: {self.end_time_formatted}")
|
|
|
|
|
|
def remove_background(image: Image.Image, threshold: int = 50) -> Image.Image:
    """Return an RGBA copy of ``image`` with its dark pixels made transparent.

    A pixel is treated as background when the mean of its R, G and B values
    is strictly below ``threshold``; such pixels are replaced with
    ``(0, 0, 0, 0)``. All other pixels are kept unchanged.

    Args:
        image: Source image; it is converted to RGBA before processing.
        threshold: Brightness cut-off compared against the RGB mean.

    Returns:
        The converted RGBA image with background pixels cleared.
    """
    rgba = image.convert("RGBA")
    transparent = (0, 0, 0, 0)
    pixels = [
        transparent if sum(pixel[:3]) / 3 < threshold else pixel
        for pixel in rgba.getdata()
    ]
    rgba.putdata(pixels)
    return rgba
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level inpainting pipeline shared by every request: loaded once from
# `bfl_repo` with bfloat16 weights and moved to DEVICE when the module is run.
pipe = FluxInpaintPipeline.from_pretrained(bfl_repo, torch_dtype=torch.bfloat16).to(DEVICE)
|
|
|
|
|
|
def resize_image_dimensions(
    original_resolution_wh: Tuple[int, int],
    maximum_dimension: int = IMAGE_SIZE
) -> Tuple[int, int]:
    """Scale a resolution so its longer side equals ``maximum_dimension``.

    The aspect ratio is preserved (up to rounding) and both resulting sides
    are rounded down to a multiple of 32, but never below 32 so the result
    is always a usable, non-empty size.

    Args:
        original_resolution_wh: ``(width, height)`` of the source image.
        maximum_dimension: Target length of the longer side before the
            rounding to a multiple of 32 is applied.

    Returns:
        ``(new_width, new_height)``, each a positive multiple of 32.

    Raises:
        ValueError: If either source dimension is not positive.
    """
    width, height = original_resolution_wh
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Image dimensions must be positive, got {width}x{height}."
        )

    longest_side = max(width, height)

    # Integer arithmetic keeps the longer side at exactly `maximum_dimension`.
    # The float form `int(side * (maximum_dimension / side))` can land one
    # below the target (e.g. 1024 / 49 * 49 == 1023.9999999999999) and then
    # lose a whole 32-pixel step in the rounding below.
    new_width = int(width * maximum_dimension // longest_side)
    new_height = int(height * maximum_dimension // longest_side)

    # Round down to a multiple of 32; clamp so that very elongated images do
    # not collapse to a zero-sized side.
    new_width = max(32, new_width - (new_width % 32))
    new_height = max(32, new_height - (new_height % 32))

    return new_width, new_height
|
|
|
|
|
|
# Upper bound, in seconds, for downloading an image or mask from a URL so a
# stalled server cannot block the request indefinitely.
_URL_FETCH_TIMEOUT_SECONDS = 30


def _fetch_image_from_url(url: str, label: str) -> Image.Image:
    """Download ``url`` and open the response body as a PIL image.

    ``label`` ("image" or "mask") is only used in the progress messages.
    Raises whatever ``requests`` raises on connection/HTTP errors and
    whatever PIL raises when the payload is not a readable image.
    """
    # NOTE(review): the URL comes straight from user input; if this app is
    # exposed publicly, consider restricting which hosts may be fetched.
    print(f"start to fetch {label} from url", url)
    response = requests.get(url, timeout=_URL_FETCH_TIMEOUT_SECONDS)
    response.raise_for_status()
    fetched = PIL.Image.open(BytesIO(response.content))
    print(f"fetch {label} success")
    return fetched


@spaces.GPU(duration=100)
def process(
    input_image_editor: dict,
    image_url: str,
    mask_url: str,
    blur_mask: bool,
    blur_factor: int,
    lora_path: str,
    lora_weights: str,
    lora_scale: float,
    trigger_word: str,
    input_text: str,
    seed_slicer: int,
    randomize_seed_checkbox: bool,
    strength_slider: float,
    num_inference_steps_slider: int,
    progress=gr.Progress(track_tqdm=True)
):
    """Inpaint the masked region of an image with the module-level ``pipe``.

    The image and mask come from the image editor value (its ``'background'``
    entry and the first entry of ``'layers'``) unless ``image_url`` /
    ``mask_url`` are given, in which case the downloaded files take
    precedence.

    Args:
        input_image_editor: Editor value; read via its ``'background'`` and
            ``'layers'`` keys. May be empty when URLs are supplied instead.
        image_url: Optional URL of the image to inpaint.
        mask_url: Optional URL of the mask image.
        blur_mask: Whether to blur the mask before use.
        blur_factor: Blur factor passed to ``pipe.mask_processor.blur``.
        lora_path: LoRA location passed to ``pipe.load_lora_weights``; when
            empty, no LoRA is loaded.
        lora_weights: Weight file name inside ``lora_path``.
        lora_scale: Passed to the pipeline as ``joint_attention_kwargs["scale"]``.
        trigger_word: Text appended to the prompt after a space.
        input_text: The inpainting prompt; required.
        seed_slicer: Seed to use when ``randomize_seed_checkbox`` is false.
        randomize_seed_checkbox: Draw a fresh seed in ``[0, MAX_SEED]``.
        strength_slider: ``strength`` argument of the pipeline.
        num_inference_steps_slider: ``num_inference_steps`` of the pipeline.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        ``([resized_image, result], resized_mask)`` on success, or
        ``(None, None)`` after showing a ``gr.Info`` message when the prompt,
        image or mask is missing.

    Raises:
        requests.RequestException: If downloading ``image_url`` / ``mask_url``
            fails, times out, or returns an error status.
    """
    if not input_text:
        gr.Info("Please enter a text prompt.")
        return None, None

    # Read the editor value defensively: indexing ['layers'][0] on an empty
    # editor used to raise before the URL inputs were even considered.
    editor_value = input_image_editor or {}
    image = editor_value.get('background')
    layers = editor_value.get('layers') or []
    mask = layers[0] if layers else None

    if image_url:
        # Downloaded files may be palette/RGBA/etc.; normalise to RGB, the
        # same mode the editor component is configured to deliver.
        image = _fetch_image_from_url(image_url, "image").convert("RGB")

    if mask_url:
        mask = _fetch_image_from_url(mask_url, "mask")

    if image is None:
        gr.Info("Please upload an image.")
        return None, None

    if mask is None:
        gr.Info("Please draw a mask on the image.")
        return None, None

    if blur_mask:
        mask = pipe.mask_processor.blur(mask, blur_factor=blur_factor)

    with calculateDuration("resize image"):
        width, height = resize_image_dimensions(original_resolution_wh=image.size)
        resized_image = image.resize((width, height), Image.LANCZOS)
        resized_mask = mask.resize((width, height), Image.LANCZOS)

    with calculateDuration("load lora"):
        print(lora_path, lora_weights)
        # `pipe` is shared across requests: drop whatever LoRA an earlier
        # call attached so adapters do not pile up from run to run.
        pipe.unload_lora_weights()
        if lora_path:
            pipe.load_lora_weights(lora_path, weight_name=lora_weights or None)

    if randomize_seed_checkbox:
        seed_slicer = random.randint(0, MAX_SEED)
    # Sliders can deliver numeric values that are not plain ints.
    generator = torch.Generator().manual_seed(int(seed_slicer))

    with calculateDuration("run pipe"):
        print(input_text, width, height, strength_slider, num_inference_steps_slider, lora_scale)
        result = pipe(
            prompt=f"{input_text} {trigger_word}",
            image=resized_image,
            mask_image=resized_mask,
            width=width,
            height=height,
            strength=strength_slider,
            generator=generator,
            num_inference_steps=num_inference_steps_slider,
            max_sequence_length=256,
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]

    return [resized_image, result], resized_mask
|
|
|
|
|
|
# Gradio UI: inputs on the left, results on the right, wired to `process`.
# NOTE(review): the indentation of this section was lost in SOURCE; the
# nesting below is reconstructed from the order of the statements — confirm
# it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Left column: image/mask inputs and all generation settings.
        with gr.Column():
            input_image_editor_component = gr.ImageEditor(
                label='Image',
                type='pil',
                sources=["upload", "webcam"],
                image_mode='RGB',
                layers=False,
                brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"))

            image_url = gr.Textbox(
                label="image url",
                show_label=True,
                max_lines=1,
                placeholder="Enter your image url (Optional)",
            )
            mask_url = gr.Textbox(
                label="Mask image url",
                show_label=True,
                max_lines=1,
                placeholder="Enter your mask image url (Optional)",
            )

            with gr.Accordion("Prompt Settings", open=True):
                input_text_component = gr.Textbox(
                    label="Inpaint prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your prompt",
                )
                trigger_word = gr.Textbox(
                    label="Lora trigger word",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your lora trigger word here",
                    value="a photo of TOK"
                )

            submit_button_component = gr.Button(
                value='Submit', variant='primary', scale=0)

            with gr.Accordion("Lora Settings", open=True):
                lora_path = gr.Textbox(
                    label="Lora model path",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your model path",
                    info="Currently, only LoRA models hosted on Hugging Face can be loaded properly.",
                    value="XLabs-AI/flux-RealismLora"
                )
                lora_weights = gr.Textbox(
                    label="Lora weights",
                    show_label=True,
                    max_lines=1,
                    placeholder="Enter your lora weights name",
                    value="lora.safetensors"
                )
                lora_scale = gr.Slider(
                    label="Lora scale",
                    show_label=True,
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.9,
                )

            with gr.Accordion("Advanced Settings", open=True):
                seed_slicer_component = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )

                randomize_seed_checkbox_component = gr.Checkbox(
                    label="Randomize seed", value=True)

                blur_mask = gr.Checkbox(
                    label="if blur mask", value=False)
                blur_factor = gr.Slider(
                    label="blur factor",
                    minimum=0,
                    maximum=50,
                    step=1,
                    value=33,
                )
                with gr.Row():
                    strength_slider_component = gr.Slider(
                        label="Strength",
                        info="Indicates extent to transform the reference `image`. "
                             "Must be between 0 and 1. `image` is used as a starting "
                             "point and more noise is added the higher the `strength`.",
                        minimum=0,
                        maximum=1,
                        step=0.01,
                        value=0.85,
                    )

                    num_inference_steps_slider_component = gr.Slider(
                        label="Number of inference steps",
                        # The help text used to stop mid-sentence ("... at the").
                        info="The number of denoising steps. More denoising steps "
                             "usually lead to a higher quality image at the "
                             "expense of slower inference.",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=28,
                    )
        # Right column: before/after slider plus the mask that was used.
        with gr.Column():
            output_image_component = ImageSlider(label="Generate image", type="pil", slider_color="pink")

            with gr.Accordion("Debug", open=False):
                output_mask_component = gr.Image(
                    type='pil', image_mode='RGB', label='Input mask', format="png")

    # The order of `inputs` must match the positional parameters of `process`.
    submit_button_component.click(
        fn=process,
        inputs=[
            input_image_editor_component,
            image_url,
            mask_url,
            blur_mask,
            blur_factor,
            lora_path,
            lora_weights,
            lora_scale,
            trigger_word,
            input_text_component,
            seed_slicer_component,
            randomize_seed_checkbox_component,
            strength_slider_component,
            num_inference_steps_slider_component
        ],
        outputs=[
            output_image_component,
            output_mask_component
        ]
    )

# Start the web server; show_error=True surfaces exceptions raised by
# `process` in the browser.
demo.launch(debug=False, show_error=True)
|
|
|